Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
0184ccab28
|
|||
|
471b9b7629
|
@@ -1,726 +0,0 @@
|
||||
name: build-prerelease
|
||||
|
||||
# Builds CUDA-flavoured neuron binaries (plus one cortex and one helexa-bench binary),
|
||||
# packages each as a Fedora RPM, signs them, and publishes to the
|
||||
# `unstable` channel at rpm.lair.cafe.
|
||||
#
|
||||
# Change-aware: the `prepare` job diffs HEAD against the git sha
|
||||
# embedded in the most recently *published* unstable RPM (per package)
|
||||
# and skips builds whose inputs didn't change. Docs-only commits build
|
||||
# nothing; gateway-only commits skip the 3 CUDA builds (and, via
|
||||
# deploy.yml's own check-update gate, the neuron restarts + model
|
||||
# cold-loads). Diffing against the published sha — not the previous
|
||||
# push — means a failed run can never cause a change to be missed.
|
||||
#
|
||||
# Lint (fmt+clippy) and test run here as parallel jobs and gate
|
||||
# `publish`; ci.yml no longer runs on pushes to main (see its trigger
|
||||
# comment), so the two workflows stop competing for the same runners.
|
||||
#
|
||||
# The published packages are versioned as e.g.
|
||||
# helexa-neuron-blackwell-0.1.16-0.1.20260518140530.gitabcdef0.fc43.x86_64
|
||||
# ^^^^^^^^^^^^^^^^^^ ^^^^^^^^
|
||||
# commit time (s) commit sha
|
||||
# so they sort BELOW the eventual 0.1.16-1 stable release, and so two
|
||||
# commits on the same day are still strictly ordered by their commit
|
||||
# timestamps (rather than by RPM-vercmp's alpha-vs-digit precedence
|
||||
# on the SHA fragment).
|
||||
|
||||
on:
|
||||
# Auto-build on every push to main so the unstable channel tracks
|
||||
# head without a manual dispatch step.
|
||||
push:
|
||||
branches: [main]
|
||||
# Manual dispatch still available to build from a non-main ref.
|
||||
# Dispatched runs skip change detection and build everything.
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ref:
|
||||
description: "Git ref to build (branch / tag / commit). Defaults to the workflow's branch."
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
# Coalesce same-ref pushes: a newer push cancels the older in-flight
|
||||
# run — the newest commit is the one we want on the fleet. The publish
|
||||
# job keeps its own `rpm-publish` group (cancel=false) so an in-flight
|
||||
# repo update is never interrupted. Runners are ephemeral (one VM per
|
||||
# job) so concurrent runs no longer race on a shared workspace; the
|
||||
# old shared `cortex-runner-pool` group with ci.yml is gone.
|
||||
concurrency:
|
||||
group: build-prerelease-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
CARGO_INCREMENTAL: "0"
|
||||
CARGO_TERM_COLOR: "always"
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
name: Resolve version stamps + change detection
|
||||
runs-on: rust
|
||||
outputs:
|
||||
version: ${{ steps.info.outputs.version }}
|
||||
release: ${{ steps.info.outputs.release }}
|
||||
short_sha: ${{ steps.info.outputs.short_sha }}
|
||||
commit_timestamp: ${{ steps.info.outputs.commit_timestamp }}
|
||||
build_cortex: ${{ steps.changes.outputs.build_cortex }}
|
||||
build_neuron: ${{ steps.changes.outputs.build_neuron }}
|
||||
build_bench: ${{ steps.changes.outputs.build_bench }}
|
||||
check_rust: ${{ steps.changes.outputs.check_rust }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
fetch-depth: 0
|
||||
|
||||
- id: info
|
||||
run: |
|
||||
set -eux
|
||||
VERSION=$(awk -F\" '/^version[[:space:]]*=/ { print $2; exit }' Cargo.toml)
|
||||
SHORT_SHA=$(git rev-parse --short=7 HEAD)
|
||||
# Second-precision commit timestamp gives the release stamp a
|
||||
# strictly monotonic numeric prefix. The earlier %Y%m%d-only
|
||||
# form let same-day builds be ordered by RPM's rpmvercmp
|
||||
# rules over the SHA, which is non-chronological — e.g.
|
||||
# "git602e8e1" sorts newer than "gitf9f5fa4" purely because
|
||||
# rpmvercmp ranks digit-prefixed segments above alpha ones.
|
||||
# The SHA stays only as a debug identifier; sort order is
|
||||
# decided entirely by the timestamp.
|
||||
COMMIT_TIMESTAMP=$(git log -1 --format=%cd --date=format:%Y%m%d%H%M%S HEAD)
|
||||
RELEASE="0.1.${COMMIT_TIMESTAMP}.git${SHORT_SHA}"
|
||||
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||
echo "release=${RELEASE}" >> "$GITHUB_OUTPUT"
|
||||
echo "short_sha=${SHORT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
echo "commit_timestamp=${COMMIT_TIMESTAMP}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- id: changes
|
||||
run: |
|
||||
set -ux
|
||||
# Default: build everything. Detection only ever narrows
|
||||
# this, and any failure along the way (manifest unreachable,
|
||||
# unparsable, sha not in history after a force-push) leaves
|
||||
# the full build in place. Manual dispatches always build
|
||||
# everything — predictable when building odd refs.
|
||||
BUILD_CORTEX=true
|
||||
BUILD_NEURON=true
|
||||
BUILD_BENCH=true
|
||||
CHECK_RUST=true
|
||||
|
||||
if [ "${GITHUB_EVENT_NAME}" = "push" ]; then
|
||||
MANIFEST_URL="https://rpm.lair.cafe/fedora/43/x86_64/unstable/packages.json"
|
||||
if curl -fsS --max-time 20 -o /tmp/packages.json "$MANIFEST_URL"; then
|
||||
# Latest published sha per package, by buildTime.
|
||||
base_for() {
|
||||
python3 - "$1" <<'PY'
|
||||
import json, re, sys
|
||||
name = sys.argv[1]
|
||||
try:
|
||||
with open("/tmp/packages.json") as f:
|
||||
pkgs = json.load(f)["packages"]
|
||||
cands = [p for p in pkgs if p.get("name") == name]
|
||||
if cands:
|
||||
latest = max(cands, key=lambda p: p.get("buildTime", 0))
|
||||
m = re.search(r"git\.?([0-9a-f]{7,40})", latest.get("release", ""))
|
||||
if m:
|
||||
print(m.group(1))
|
||||
except Exception:
|
||||
pass
|
||||
PY
|
||||
}
|
||||
|
||||
# true if no usable base, else true iff the diff since
|
||||
# the published sha touches the given path pattern.
|
||||
decide() {
|
||||
local base="$1" pattern="$2"
|
||||
if [ -z "$base" ] \
|
||||
|| ! git cat-file -e "${base}^{commit}" 2>/dev/null \
|
||||
|| ! git merge-base --is-ancestor "$base" HEAD 2>/dev/null; then
|
||||
echo true; return
|
||||
fi
|
||||
if git diff --name-only "${base}..HEAD" | grep -qE "$pattern"; then
|
||||
echo true
|
||||
else
|
||||
echo false
|
||||
fi
|
||||
}
|
||||
|
||||
# cortex-core is shared by all three binaries; Cargo.{toml,lock}
|
||||
# affect all three; this workflow file affects all three.
|
||||
NEURON_RE='^crates/neuron/|^crates/cortex-core/|^Cargo\.toml$|^Cargo\.lock$|^rpm/helexa-neuron-prerelease\.spec$|^data/neuron|^neuron\.example\.toml$|^\.gitea/workflows/build-prerelease\.yml$'
|
||||
CORTEX_RE='^crates/cortex-gateway/|^crates/cortex-cli/|^crates/cortex-core/|^Cargo\.toml$|^Cargo\.lock$|^rpm/cortex-prerelease\.spec$|^data/cortex|^cortex\.example\.toml$|^models\.example\.toml$|^\.gitea/workflows/build-prerelease\.yml$'
|
||||
BENCH_RE='^crates/helexa-bench/|^crates/cortex-core/|^Cargo\.toml$|^Cargo\.lock$|^rpm/helexa-bench-prerelease\.spec$|^data/helexa-bench|^helexa-bench\.example\.toml$|^\.gitea/workflows/build-prerelease\.yml$'
|
||||
# Any Rust change (incl. crates not packaged here, e.g.
|
||||
# helexa-acp) still needs lint+test on main.
|
||||
RUST_RE='\.rs$|^crates/|Cargo\.toml$|^Cargo\.lock$'
|
||||
|
||||
CORTEX_BASE=$(base_for cortex)
|
||||
NEURON_BASE=$(base_for helexa-neuron-blackwell)
|
||||
BENCH_BASE=$(base_for helexa-bench)
|
||||
BUILD_CORTEX=$(decide "$CORTEX_BASE" "$CORTEX_RE")
|
||||
BUILD_NEURON=$(decide "$NEURON_BASE" "$NEURON_RE")
|
||||
BUILD_BENCH=$(decide "$BENCH_BASE" "$BENCH_RE")
|
||||
if [ "$BUILD_CORTEX" = "true" ] || [ "$BUILD_NEURON" = "true" ] || [ "$BUILD_BENCH" = "true" ]; then
|
||||
CHECK_RUST=true
|
||||
else
|
||||
CHECK_RUST=$(decide "$CORTEX_BASE" "$RUST_RE")
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "build_cortex=${BUILD_CORTEX}" >> "$GITHUB_OUTPUT"
|
||||
echo "build_neuron=${BUILD_NEURON}" >> "$GITHUB_OUTPUT"
|
||||
echo "build_bench=${BUILD_BENCH}" >> "$GITHUB_OUTPUT"
|
||||
echo "check_rust=${CHECK_RUST}" >> "$GITHUB_OUTPUT"
|
||||
echo "### change detection: build_cortex=${BUILD_CORTEX} build_neuron=${BUILD_NEURON} build_bench=${BUILD_BENCH} check_rust=${CHECK_RUST}"
|
||||
|
||||
# fmt + clippy + test moved here from ci.yml for main pushes so the
|
||||
# two workflows stop queueing against each other (ci.yml's checks
|
||||
# used to delay build-cortex by ~12 minutes on the shared runner
|
||||
# pool). They run in parallel with the builds and gate `publish`,
|
||||
# not the builds themselves — a clippy warning still can't reach the
|
||||
# fleet, but it also doesn't serialize the pipeline.
|
||||
lint:
|
||||
name: Lint (fmt + clippy)
|
||||
needs: prepare
|
||||
if: needs.prepare.outputs.check_rust == 'true'
|
||||
runs-on: rust
|
||||
env:
|
||||
RUSTC_WRAPPER: sccache
|
||||
SCCACHE_BUCKET: sccache
|
||||
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
|
||||
SCCACHE_REGION: auto
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
- run: cargo fmt --check --all
|
||||
# sccache failures come in two modes: transient races (a plain
|
||||
# retry clears them) and a wedged/dead server, where every
|
||||
# same-VM retry fails identically (sccache fatal error, ENOENT
|
||||
# on its own tmp files). Escalate accordingly: retry → restart
|
||||
# the server → final attempt uncached. A sick cache costs build
|
||||
# time, never the run.
|
||||
- name: Clippy (with sccache escalation)
|
||||
run: |
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::clippy attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo clippy --workspace -- -D warnings; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "clippy failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ]; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "clippy failed after 3 attempts"
|
||||
exit 1
|
||||
- run: sccache --show-stats || true
|
||||
|
||||
test:
|
||||
name: Test
|
||||
needs: prepare
|
||||
if: needs.prepare.outputs.check_rust == 'true'
|
||||
runs-on: rust
|
||||
env:
|
||||
RUSTC_WRAPPER: sccache
|
||||
SCCACHE_BUCKET: sccache
|
||||
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
|
||||
SCCACHE_REGION: auto
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
# See the lint job for the escalation rationale.
|
||||
- name: Test (with sccache escalation)
|
||||
run: |
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::test attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo test --workspace; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "test failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ]; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "test failed after 3 attempts"
|
||||
exit 1
|
||||
- run: sccache --show-stats || true
|
||||
|
||||
build-cortex:
|
||||
name: Build cortex binary
|
||||
needs: prepare
|
||||
if: needs.prepare.outputs.build_cortex == 'true'
|
||||
# runner-rust image already provides rust/cargo/clippy/rustfmt via
|
||||
# dnf — no rustup install step needed.
|
||||
runs-on: rust
|
||||
env:
|
||||
RUSTC_WRAPPER: sccache
|
||||
SCCACHE_BUCKET: sccache
|
||||
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
|
||||
SCCACHE_REGION: auto
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
# Escalation mirrors the lint/test jobs: retry → restart the
|
||||
# sccache server → final attempt uncached. A sick cache costs
|
||||
# build time, never the run.
|
||||
- name: Build cortex (release, with sccache escalation)
|
||||
run: |
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::build attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo build --release -p cortex-cli; then
|
||||
echo "::endgroup::"
|
||||
sccache --show-stats || true
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "build failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ]; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "build failed after 3 attempts"
|
||||
exit 1
|
||||
|
||||
- name: Stage binary
|
||||
run: |
|
||||
mkdir --parents artifacts
|
||||
cp target/release/cortex artifacts/cortex
|
||||
./artifacts/cortex --version || true
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: cortex-fc43
|
||||
path: artifacts/cortex
|
||||
retention-days: 1
|
||||
|
||||
build-bench:
|
||||
name: Build helexa-bench binary
|
||||
needs: prepare
|
||||
if: needs.prepare.outputs.build_bench == 'true'
|
||||
# Pure-Rust, non-CUDA binary — same runner as cortex.
|
||||
runs-on: rust
|
||||
env:
|
||||
RUSTC_WRAPPER: sccache
|
||||
SCCACHE_BUCKET: sccache
|
||||
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
|
||||
SCCACHE_REGION: auto
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Build helexa-bench (release, with sccache escalation)
|
||||
run: |
|
||||
# Stamp the SHA helexa-bench records as bench_sha against every
|
||||
# run (option_env! in sweep.rs reads it at compile time).
|
||||
export HELEXA_BUILD_SHA="$(git rev-parse HEAD)"
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::build attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo build --release -p helexa-bench; then
|
||||
echo "::endgroup::"
|
||||
sccache --show-stats || true
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "build failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ]; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "build failed after 3 attempts"
|
||||
exit 1
|
||||
|
||||
- name: Stage binary
|
||||
run: |
|
||||
mkdir --parents artifacts
|
||||
cp target/release/helexa-bench artifacts/helexa-bench
|
||||
./artifacts/helexa-bench --version || true
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: bench-fc43
|
||||
path: artifacts/helexa-bench
|
||||
retention-days: 1
|
||||
|
||||
build-neuron:
|
||||
name: Build neuron-${{ matrix.flavour }}
|
||||
needs: prepare
|
||||
if: needs.prepare.outputs.build_neuron == 'true'
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- flavour: ampere
|
||||
compute_cap: "86"
|
||||
runner: cuda-13.0
|
||||
cuda_home: /usr/local/cuda-13.0
|
||||
build_jobs: 8
|
||||
nvcc_threads: 4
|
||||
cargo_features: "cuda cudnn"
|
||||
- flavour: ada
|
||||
compute_cap: "89"
|
||||
runner: cuda-13.0
|
||||
cuda_home: /usr/local/cuda-13.0
|
||||
build_jobs: 8
|
||||
nvcc_threads: 4
|
||||
cargo_features: "cuda cudnn"
|
||||
- flavour: blackwell
|
||||
compute_cap: "120"
|
||||
runner: cuda-13.0
|
||||
cuda_home: /usr/local/cuda-13.0
|
||||
build_jobs: 8
|
||||
nvcc_threads: 4
|
||||
cargo_features: "cuda cudnn"
|
||||
runs-on: ${{ matrix.runner }}
|
||||
env:
|
||||
SCCACHE_BUCKET: sccache
|
||||
SCCACHE_ENDPOINT: http://caveman.kosherinata.internal:9000
|
||||
SCCACHE_REGION: auto
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
# Escalation mirrors the lint/test jobs: retry → restart the
|
||||
# sccache server → final attempt uncached.
|
||||
#
|
||||
# The CUDA image may or may not ship sccache — probe inside this
|
||||
# step (NOT via GITHUB_ENV from a prior step, which this runner
|
||||
# does not propagate; observed: probe step said "enabled", build
|
||||
# ran unwrapped, server stats showed 4 compile requests). A
|
||||
# missing binary degrades to an uncached build rather than
|
||||
# failing cargo at `sccache rustc -vV`. The cache covers the
|
||||
# ~600-crate host-side dep tree (the bulk of the 10-14 min
|
||||
# build); rustc compilations are shared across all three
|
||||
# flavours, so even one run seeds the next.
|
||||
- name: Build neuron with CUDA (${{ matrix.flavour }})
|
||||
run: |
|
||||
set -ux
|
||||
if command -v sccache >/dev/null 2>&1; then
|
||||
export RUSTC_WRAPPER=sccache
|
||||
sccache --start-server 2>/dev/null || true
|
||||
echo "sccache enabled"
|
||||
else
|
||||
echo "sccache not on PATH — building uncached"
|
||||
fi
|
||||
export PATH="${{ matrix.cuda_home }}/bin:${PATH}"
|
||||
export LD_LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LD_LIBRARY_PATH:-}"
|
||||
export LIBRARY_PATH="${{ matrix.cuda_home }}/targets/x86_64-linux/lib:${{ matrix.cuda_home }}/lib64:${LIBRARY_PATH:-}"
|
||||
# Pin the build SHA neuron reports from GET /version. The git
|
||||
# fallback in build.rs would also work on a full checkout, but
|
||||
# injecting the exact checked-out commit is unambiguous under
|
||||
# shallow/detached states and makes the artifact self-describing.
|
||||
export HELEXA_BUILD_SHA="$(git rev-parse HEAD)"
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::build attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo build --release -p neuron --features "${{ matrix.cargo_features }}"; then
|
||||
echo "::endgroup::"
|
||||
command -v sccache >/dev/null 2>&1 && sccache --show-stats || true
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "build failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ] && command -v sccache >/dev/null 2>&1; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "build failed after 3 attempts"
|
||||
exit 1
|
||||
env:
|
||||
CUDA_COMPUTE_CAP: ${{ matrix.compute_cap }}
|
||||
CARGO_BUILD_JOBS: ${{ matrix.build_jobs }}
|
||||
NVCC_THREADS: ${{ matrix.nvcc_threads }}
|
||||
|
||||
- name: Stage binary
|
||||
run: |
|
||||
mkdir --parents artifacts
|
||||
cp target/release/neuron artifacts/neuron-${{ matrix.flavour }}
|
||||
file "artifacts/neuron-${{ matrix.flavour }}"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: neuron-${{ matrix.flavour }}-fc43
|
||||
path: artifacts/neuron-${{ matrix.flavour }}
|
||||
retention-days: 1
|
||||
|
||||
package-cortex:
|
||||
name: Package cortex RPM
|
||||
needs: [prepare, build-cortex]
|
||||
runs-on: rpm
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: cortex-fc43
|
||||
path: artifacts/
|
||||
|
||||
- name: Build RPM
|
||||
run: |
|
||||
set -eux
|
||||
rm -f ~/.rpmmacros
|
||||
rpmdev-setuptree
|
||||
cp artifacts/cortex ~/rpmbuild/SOURCES/
|
||||
cp data/cortex.service ~/rpmbuild/SOURCES/
|
||||
cp data/cortex-sysusers.conf ~/rpmbuild/SOURCES/
|
||||
cp data/cortex-firewalld.xml ~/rpmbuild/SOURCES/
|
||||
cp cortex.example.toml ~/rpmbuild/SOURCES/
|
||||
cp models.example.toml ~/rpmbuild/SOURCES/
|
||||
cp LICENSE ~/rpmbuild/SOURCES/
|
||||
rpmbuild -bb rpm/cortex-prerelease.spec \
|
||||
--define "cortex_version ${{ needs.prepare.outputs.version }}" \
|
||||
--define "cortex_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||
--undefine dist \
|
||||
--define "dist .fc43"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: rpm-cortex-fc43
|
||||
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||
retention-days: 7
|
||||
|
||||
package-bench:
|
||||
name: Package helexa-bench RPM
|
||||
needs: [prepare, build-bench]
|
||||
runs-on: rpm
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: bench-fc43
|
||||
path: artifacts/
|
||||
|
||||
- name: Build RPM
|
||||
run: |
|
||||
set -eux
|
||||
rm -f ~/.rpmmacros
|
||||
rpmdev-setuptree
|
||||
cp artifacts/helexa-bench ~/rpmbuild/SOURCES/
|
||||
cp data/helexa-bench.service ~/rpmbuild/SOURCES/
|
||||
cp data/helexa-bench-sysusers.conf ~/rpmbuild/SOURCES/
|
||||
cp helexa-bench.example.toml ~/rpmbuild/SOURCES/
|
||||
cp LICENSE ~/rpmbuild/SOURCES/
|
||||
rpmbuild -bb rpm/helexa-bench-prerelease.spec \
|
||||
--define "bench_version ${{ needs.prepare.outputs.version }}" \
|
||||
--define "bench_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||
--undefine dist \
|
||||
--define "dist .fc43"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: rpm-bench-fc43
|
||||
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||
retention-days: 7
|
||||
|
||||
package-neuron:
|
||||
name: Package helexa-neuron-${{ matrix.flavour }} RPM
|
||||
needs: [prepare, build-neuron]
|
||||
runs-on: rpm
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- flavour: ampere
|
||||
- flavour: ada
|
||||
- flavour: blackwell
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: neuron-${{ matrix.flavour }}-fc43
|
||||
path: artifacts/
|
||||
|
||||
- name: Build RPM
|
||||
run: |
|
||||
set -eux
|
||||
rm -f ~/.rpmmacros
|
||||
rpmdev-setuptree
|
||||
cp artifacts/neuron-${{ matrix.flavour }} ~/rpmbuild/SOURCES/
|
||||
cp data/neuron.service ~/rpmbuild/SOURCES/
|
||||
cp data/neuron-sysusers.conf ~/rpmbuild/SOURCES/
|
||||
cp data/neuron-firewalld.xml ~/rpmbuild/SOURCES/
|
||||
cp neuron.example.toml ~/rpmbuild/SOURCES/
|
||||
cp LICENSE ~/rpmbuild/SOURCES/
|
||||
rpmbuild -bb rpm/helexa-neuron-prerelease.spec \
|
||||
--define "neuron_version ${{ needs.prepare.outputs.version }}" \
|
||||
--define "neuron_flavour ${{ matrix.flavour }}" \
|
||||
--define "neuron_prerelease ${{ needs.prepare.outputs.release }}" \
|
||||
--undefine dist \
|
||||
--define "dist .fc43"
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: rpm-neuron-${{ matrix.flavour }}-fc43
|
||||
path: ~/rpmbuild/RPMS/x86_64/*.rpm
|
||||
retention-days: 7
|
||||
|
||||
publish:
|
||||
name: Publish to rpm.lair.cafe (unstable)
|
||||
needs: [lint, test, package-cortex, package-neuron, package-bench]
|
||||
# Runs when at least one package was built and nothing failed.
|
||||
# lint/test may be skipped (docs-only refs never get here because
|
||||
# no packages build), but a real failure in any blocks the
|
||||
# fleet from receiving the RPMs.
|
||||
if: >-
|
||||
${{
|
||||
!cancelled()
|
||||
&& (needs.lint.result == 'success' || needs.lint.result == 'skipped')
|
||||
&& (needs.test.result == 'success' || needs.test.result == 'skipped')
|
||||
&& (needs.package-cortex.result == 'success' || needs.package-neuron.result == 'success' || needs.package-bench.result == 'success')
|
||||
&& needs.package-cortex.result != 'failure'
|
||||
&& needs.package-neuron.result != 'failure'
|
||||
&& needs.package-bench.result != 'failure'
|
||||
}}
|
||||
runs-on: rpm
|
||||
concurrency:
|
||||
group: rpm-publish
|
||||
cancel-in-progress: false
|
||||
env:
|
||||
RPM_REPO_HOST: oolon.kosherinata.internal
|
||||
FEDORA_VERSION: "43"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Download all built RPMs
|
||||
uses: actions/download-artifact@v3
|
||||
with:
|
||||
path: rpms/
|
||||
pattern: rpm-*-fc43
|
||||
|
||||
- name: Flatten RPM artifacts
|
||||
run: |
|
||||
set -eux
|
||||
find rpms/ -name '*.rpm' -exec mv --target-directory=rpms/ {} +
|
||||
find rpms/ -mindepth 1 -type d -empty -delete
|
||||
ls -la rpms/
|
||||
|
||||
- name: Check for sequoia-sq
|
||||
run: |
|
||||
if ! command -v sq &> /dev/null; then
|
||||
echo "ERROR: sequoia-sq is not installed. Install with: sudo dnf install sequoia-sq"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Import signing key
|
||||
env:
|
||||
# Pass secrets via env so values stay out of the rendered shell
|
||||
# script (which Gitea includes in step logs). Template
|
||||
# expansion of ${{ secrets.X }} inside `run:` writes the literal
|
||||
# value into the script and depends on Gitea's log masker to
|
||||
# scrub it — fragile for multi-line keys.
|
||||
RPM_SIGNING_KEY: ${{ secrets.RPM_SIGNING_KEY }}
|
||||
RPM_SIGNING_KEY_ID: ${{ secrets.RPM_SIGNING_KEY_ID }}
|
||||
run: |
|
||||
echo "$RPM_SIGNING_KEY" | gpg --batch --import
|
||||
fpr=$(gpg --batch --with-colons --list-keys "$RPM_SIGNING_KEY_ID" | awk -F: '/^fpr:/ { print $10; exit }')
|
||||
echo "${fpr}:6:" | gpg --batch --import-ownertrust
|
||||
sed "s/@GPG_NAME@/$RPM_SIGNING_KEY_ID/" rpm/rpmmacros > ~/.rpmmacros
|
||||
|
||||
- name: Sign RPMs
|
||||
run: |
|
||||
set -eux
|
||||
for rpm in rpms/*.rpm; do
|
||||
echo "signing ${rpm}..."
|
||||
rpm --addsign "${rpm}"
|
||||
done
|
||||
|
||||
- name: Set up SSH for rsync
|
||||
run: |
|
||||
install --directory --mode 700 ~/.ssh
|
||||
echo "${RSYNC_SSH_KEY}" | install --mode 600 /dev/stdin ~/.ssh/id_ed25519
|
||||
env:
|
||||
RSYNC_SSH_KEY: ${{ secrets.RSYNC_SSH_KEY }}
|
||||
|
||||
- name: Test SSH connectivity
|
||||
run: |
|
||||
ssh -o StrictHostKeyChecking=accept-new "gitea_ci@${RPM_REPO_HOST}" exit
|
||||
|
||||
- name: Ensure unstable repo directory exists
|
||||
run: |
|
||||
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||
"mkdir --parents /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||
|
||||
- name: Sync RPMs to unstable repo
|
||||
run: |
|
||||
rsync \
|
||||
--archive \
|
||||
--verbose \
|
||||
--chmod D755,F644 \
|
||||
rpms/*.rpm \
|
||||
"gitea_ci@${RPM_REPO_HOST}:/var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/"
|
||||
|
||||
- name: Update unstable repo metadata
|
||||
run: |
|
||||
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||
"cd /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable && createrepo_c --update ."
|
||||
|
||||
- name: Generate packages.json manifest
|
||||
run: |
|
||||
scp script/generate-packages-json.py "gitea_ci@${RPM_REPO_HOST}:/tmp/"
|
||||
ssh "gitea_ci@${RPM_REPO_HOST}" \
|
||||
"python3 /tmp/generate-packages-json.py \
|
||||
--repodata-dir /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/repodata \
|
||||
--output /var/www/rpm/fedora/${FEDORA_VERSION}/x86_64/unstable/packages.json \
|
||||
--base-url https://rpm.lair.cafe/fedora/${FEDORA_VERSION}/x86_64/unstable"
|
||||
@@ -1,26 +1,12 @@
|
||||
name: CI
|
||||
|
||||
# Pushes to main are deliberately excluded: build-prerelease.yml runs
|
||||
# its own lint/test jobs there (gating publish), and running both
|
||||
# workflows on the same push made them queue against each other on the
|
||||
# same runner labels — ~12 minutes of added latency per deploy. Feature
|
||||
# branches, PRs to main, and release tags keep the full gate here.
|
||||
on:
|
||||
push:
|
||||
branches-ignore: [main]
|
||||
branches: ["**"]
|
||||
tags: ["v*"]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
# Coalesce same-ref pushes; a newer push supersedes the in-flight run.
|
||||
# (The old shared `cortex-runner-pool` group with build-prerelease.yml
|
||||
# is gone — the workflows no longer trigger on the same refs, and
|
||||
# ephemeral one-VM-per-job runners removed the shared-workspace race
|
||||
# that group existed to serialize.)
|
||||
concurrency:
|
||||
group: ci-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
CARGO_INCREMENTAL: "0"
|
||||
RUSTC_WRAPPER: sccache
|
||||
@@ -30,163 +16,40 @@ env:
|
||||
SCCACHE_S3_USE_SSL: "false"
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.SCCACHE_S3_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.SCCACHE_S3_SECRET_KEY }}
|
||||
# fmt, clippy, and test all run in parallel on the same `rust` runner
|
||||
# and would otherwise share /root/.cache/act/<hash>/hostexecutor/target/,
|
||||
# racing each other's cargo temp files (.tmpXXXXXX) and failing builds
|
||||
# mid-compile. Give each job its own target directory so the invocations
|
||||
# don't collide. sccache still backs the actual rustc cache, so the
|
||||
# rebuild penalty is small.
|
||||
CARGO_TARGET_DIR: target-${{ github.job }}
|
||||
|
||||
jobs:
|
||||
fmt:
|
||||
name: Format
|
||||
runs-on: rust
|
||||
check:
|
||||
name: Format, lint, build, test
|
||||
runs-on: fedora
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: cargo fmt --check --all
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: rust
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
# sccache failures come in two modes: transient races (a plain
|
||||
# retry clears them) and a wedged/dead server, where every
|
||||
# same-VM retry fails identically. Escalate: retry → restart the
|
||||
# server → final attempt uncached. A sick cache costs build
|
||||
# time, never the run. Keep in sync with build-prerelease.yml.
|
||||
- name: Clippy (with sccache escalation)
|
||||
- name: Ensure sccache with S3 support
|
||||
env:
|
||||
RUSTC_WRAPPER: ""
|
||||
run: |
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::clippy attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo clippy --workspace -- -D warnings; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "clippy failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ]; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "clippy failed after 3 attempts"
|
||||
exit 1
|
||||
- run: sccache --show-stats || true
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: rust
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
# See the clippy job for the escalation rationale.
|
||||
- name: Test (with sccache escalation)
|
||||
run: |
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::test attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo test --workspace; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "test failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ]; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "test failed after 3 attempts"
|
||||
exit 1
|
||||
- run: sccache --show-stats || true
|
||||
|
||||
# Type-check the CUDA-only code path. Borrow-check-only — we
|
||||
# never run the tests here (the runner has no GPU). This catches
|
||||
# the category of bug where a refactor compiles fine under the
|
||||
# default feature set (which is what the `clippy` and `test` jobs
|
||||
# exercise) but fails inside a `#[cfg(feature = "cuda")]` block.
|
||||
# `runs-on: cuda-13.0` selects the runner that ships nvcc /
|
||||
# cudarc's build prerequisites. The generic `rust` and `rpm`
|
||||
# runners don't have them (the previous label `rpm` was tried
|
||||
# first and tripped cudarc's `nvcc --version` build script —
|
||||
# see commit history).
|
||||
cuda-check:
|
||||
name: CUDA type-check
|
||||
runs-on: cuda-13.0
|
||||
# The workflow-level env sets `RUSTC_WRAPPER: sccache`
|
||||
# unconditionally, which hard-fails cargo if the CUDA image
|
||||
# doesn't ship sccache. Clear it at job level; the "Enable
|
||||
# sccache when available" step opts back in only after probing
|
||||
# for the binary. SCCACHE_*/AWS creds stay set — harmless when
|
||||
# the wrapper is off, required when it's on.
|
||||
env:
|
||||
RUSTC_WRAPPER: ""
|
||||
# candle-kernels' build script falls back to `nvidia-smi` for
|
||||
# compute-cap detection when this is unset — and the GPU-less
|
||||
# builder image doesn't ship nvidia-smi. Any valid cap works for
|
||||
# a borrow-check; the real per-flavour caps live in
|
||||
# build-prerelease.yml's matrix.
|
||||
CUDA_COMPUTE_CAP: "86"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
# sccache is probed inside this step (NOT via GITHUB_ENV from a
|
||||
# prior step — this runner doesn't propagate it; see
|
||||
# build-prerelease.yml for the observed failure).
|
||||
- name: cargo check --features cuda (with sccache escalation)
|
||||
run: |
|
||||
if command -v sccache >/dev/null 2>&1; then
|
||||
export RUSTC_WRAPPER=sccache
|
||||
sccache --start-server 2>/dev/null || true
|
||||
echo "sccache enabled"
|
||||
if sccache --version 2>/dev/null && sccache --show-stats 2>/dev/null; then
|
||||
echo "sccache with S3 support already installed"
|
||||
else
|
||||
echo "sccache not on PATH — building uncached"
|
||||
cargo install sccache --features s3 --locked
|
||||
fi
|
||||
# act launches the step shell without /etc/profile, so the
|
||||
# gitea_runner user's inherited PATH lacks /usr/local/cuda-13.0/bin.
|
||||
# cudarc's build.rs:157 shells out to `nvcc --version` (because
|
||||
# the neuron crate enables cuda-version-from-build-system) and
|
||||
# panics with ENOENT if nvcc isn't resolvable. build-prerelease.yml
|
||||
# does the same export — keep them in sync.
|
||||
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
|
||||
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
|
||||
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
|
||||
# Escalation mirrors the lint/test jobs: plain retry →
|
||||
# sccache server restart → final attempt uncached.
|
||||
for attempt in 1 2 3; do
|
||||
echo "::group::cuda-check attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 3 ]; then
|
||||
echo "final attempt: building without sccache"
|
||||
export RUSTC_WRAPPER=""
|
||||
fi
|
||||
if cargo check -p neuron --features cuda --all-targets; then
|
||||
echo "::endgroup::"
|
||||
exit 0
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
echo "cuda-check failed on attempt ${attempt}"
|
||||
if [ "${attempt}" -eq 1 ] && command -v sccache >/dev/null 2>&1; then
|
||||
sccache --stop-server || true
|
||||
sccache --start-server || true
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "cuda-check failed after 3 attempts"
|
||||
exit 1
|
||||
|
||||
- name: Check formatting
|
||||
run: cargo fmt --check --all
|
||||
|
||||
- name: Clippy
|
||||
run: cargo clippy --workspace -- -D warnings
|
||||
|
||||
- name: Test
|
||||
run: cargo test --workspace
|
||||
|
||||
- name: Show sccache stats
|
||||
run: sccache --show-stats
|
||||
|
||||
srpm-cortex:
|
||||
name: Build cortex SRPM
|
||||
runs-on: rpm
|
||||
needs: [fmt, clippy, test, cuda-check]
|
||||
runs-on: fedora
|
||||
needs: check
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -245,8 +108,8 @@ jobs:
|
||||
|
||||
srpm-neuron:
|
||||
name: Build neuron SRPM
|
||||
runs-on: rpm
|
||||
needs: [fmt, clippy, test, cuda-check]
|
||||
runs-on: fedora
|
||||
needs: check
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -305,7 +168,7 @@ jobs:
|
||||
|
||||
copr-cortex:
|
||||
name: Publish cortex to COPR
|
||||
runs-on: fedora-43
|
||||
runs-on: fedora
|
||||
needs: srpm-cortex
|
||||
steps:
|
||||
- name: Download SRPM
|
||||
@@ -322,7 +185,7 @@ jobs:
|
||||
|
||||
copr-neuron:
|
||||
name: Publish neuron to COPR
|
||||
runs-on: fedora-43
|
||||
runs-on: fedora
|
||||
needs: srpm-neuron
|
||||
steps:
|
||||
- name: Download SRPM
|
||||
@@ -339,7 +202,7 @@ jobs:
|
||||
|
||||
bump-version:
|
||||
name: Bump version in source
|
||||
runs-on: rust
|
||||
runs-on: fedora
|
||||
needs: [copr-cortex, copr-neuron]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -382,6 +245,6 @@ jobs:
|
||||
echo "Nothing to commit for ${VERSION}"
|
||||
else
|
||||
git commit -m "chore: bump version to ${VERSION}"
|
||||
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/${{ github.repository }}.git"
|
||||
git remote set-url origin "https://gitea-actions:${GITEA_TOKEN}@git.lair.cafe/helexa/cortex.git"
|
||||
git push origin HEAD:main
|
||||
fi
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
name: deploy-dev
|
||||
|
||||
# Fast-path iteration deploy for a SINGLE neuron host: build one CUDA
|
||||
# flavour, copy the raw binary to the host, restart neuron.service.
|
||||
# Skips the other two flavours, all RPM packaging, signing, repo
|
||||
# publish, and dnf — push-to-testable drops from ~20 min to roughly
|
||||
# one CUDA build plus a service restart.
|
||||
#
|
||||
# This is a DEV convenience, not a release path:
|
||||
# - the binary lands at /usr/bin/neuron *outside* RPM ownership;
|
||||
# the next regular deploy.yml run reconciles the host back to the
|
||||
# packaged binary (dnf sees the newer RPM and reinstalls). `rpm -V
|
||||
# helexa-neuron-<flavour>` flagging a modified /usr/bin/neuron in
|
||||
# the interim is expected.
|
||||
# - nothing is published; other hosts are untouched.
|
||||
# - requires the `install` sudoers rule from
|
||||
# asset/sudoers.d/neuron-host.conf (re-run script/infra-setup.sh
|
||||
# after updating it).
|
||||
#
|
||||
# Trigger from the Gitea UI: Actions → deploy-dev → Run workflow,
|
||||
# pick the target host. Defaults to the ref you dispatch from, so it
|
||||
# works from feature branches without touching main.
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
target:
|
||||
description: "neuron host to deploy to"
|
||||
required: true
|
||||
type: choice
|
||||
options: [beast, benjy, quadbrat]
|
||||
default: beast
|
||||
|
||||
# One dev deploy at a time; a newer dispatch for the same host wins.
|
||||
concurrency:
|
||||
group: deploy-dev-${{ inputs.target }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
CARGO_INCREMENTAL: "0"
|
||||
CARGO_TERM_COLOR: "always"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build neuron (${{ inputs.target }})
|
||||
runs-on: cuda-13.0
|
||||
outputs:
|
||||
flavour: ${{ steps.map.outputs.flavour }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
# host → flavour → compute cap. Keep in sync with the
|
||||
# build-neuron matrix in build-prerelease.yml and the
|
||||
# deploy-neurons matrix in deploy.yml.
|
||||
- id: map
|
||||
run: |
|
||||
case "${{ inputs.target }}" in
|
||||
beast) flavour=blackwell cap=120 ;;
|
||||
benjy) flavour=ada cap=89 ;;
|
||||
quadbrat) flavour=ampere cap=86 ;;
|
||||
*) echo "unknown target ${{ inputs.target }}"; exit 1 ;;
|
||||
esac
|
||||
echo "flavour=${flavour}" >> "$GITHUB_OUTPUT"
|
||||
echo "cap=${cap}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build neuron with CUDA
|
||||
run: |
|
||||
set -eux
|
||||
export PATH="/usr/local/cuda-13.0/bin:${PATH}"
|
||||
export LD_LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LD_LIBRARY_PATH:-}"
|
||||
export LIBRARY_PATH="/usr/local/cuda-13.0/targets/x86_64-linux/lib:/usr/local/cuda-13.0/lib64:${LIBRARY_PATH:-}"
|
||||
cargo build --release -p neuron --features "cuda cudnn"
|
||||
env:
|
||||
CUDA_COMPUTE_CAP: ${{ steps.map.outputs.cap }}
|
||||
CARGO_BUILD_JOBS: "8"
|
||||
NVCC_THREADS: "4"
|
||||
|
||||
- name: Stage binary
|
||||
run: |
|
||||
mkdir --parents artifacts
|
||||
cp target/release/neuron artifacts/neuron-dev
|
||||
file artifacts/neuron-dev
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: neuron-dev-${{ inputs.target }}
|
||||
path: artifacts/neuron-dev
|
||||
retention-days: 1
|
||||
|
||||
deploy:
|
||||
name: Deploy to ${{ inputs.target }}
|
||||
needs: build
|
||||
runs-on: fedora-43
|
||||
env:
|
||||
DEPLOY_KEY: |
|
||||
${{ secrets.RSYNC_SSH_KEY }}
|
||||
TARGET_HOST: ${{ inputs.target }}.hanzalova.internal
|
||||
steps:
|
||||
- name: SSH init
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
|
||||
chmod 600 ~/.ssh/id_ed25519
|
||||
ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
|
||||
"gitea_ci@${TARGET_HOST}" 'hostname -f'
|
||||
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: neuron-dev-${{ inputs.target }}
|
||||
path: artifacts/
|
||||
|
||||
- name: Copy binary to host
|
||||
run: |
|
||||
scp artifacts/neuron-dev "gitea_ci@${TARGET_HOST}:/var/lib/gitea_ci/neuron-dev"
|
||||
|
||||
- name: Install binary and restart neuron.service
|
||||
run: |
|
||||
ssh "gitea_ci@${TARGET_HOST}" '
|
||||
set -eu
|
||||
if systemctl is-active --quiet neuron.service; then
|
||||
sudo /usr/bin/systemctl stop neuron.service
|
||||
fi
|
||||
# Exact command form required by the sudoers rule in
|
||||
# asset/sudoers.d/neuron-host.conf — change both together.
|
||||
sudo /usr/bin/install -o root -g root -m 0755 /var/lib/gitea_ci/neuron-dev /usr/bin/neuron
|
||||
sudo /usr/bin/systemctl start neuron.service
|
||||
rm -f /var/lib/gitea_ci/neuron-dev'
|
||||
|
||||
- name: Capture neuron.service startup journal
|
||||
if: always()
|
||||
run: |
|
||||
sleep 10
|
||||
ssh "gitea_ci@${TARGET_HOST}" \
|
||||
'journalctl --unit neuron.service -I --no-pager'
|
||||
@@ -1,252 +0,0 @@
|
||||
name: deploy
|
||||
|
||||
# Roll the freshly-published unstable RPMs onto the helexa fleet:
|
||||
# cortex on the gateway, helexa-neuron-<flavour> on each neuron host.
|
||||
#
|
||||
# Triggered automatically after `build-prerelease` succeeds (by which
|
||||
# point the new RPMs are live on rpm.lair.cafe/unstable), and also
|
||||
# re-runnable manually from the Gitea UI.
|
||||
#
|
||||
# Each host self-gates: if dnf sees no newer package than what is
|
||||
# installed, the service is left alone — no stop, no restart, no model
|
||||
# cold-load. Combined with build-prerelease's change detection this
|
||||
# means a docs- or gateway-only push never restarts the neurons (a
|
||||
# neuron restart costs ~5 min of 27B cold-load, see issue #1).
|
||||
#
|
||||
# Per-host one-time setup (gitea_ci user, authorized_keys, scoped
|
||||
# sudoers drop-in) lives in script/infra-setup.sh — run that once per
|
||||
# host before this workflow can succeed.
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: [build-prerelease]
|
||||
types: [completed]
|
||||
workflow_dispatch:
|
||||
|
||||
# Serialize deploys. Overlapping runs would race on dnf metadata
|
||||
# refresh and service-restart timing; queueing keeps the fleet
|
||||
# predictable. Don't cancel an in-flight deploy — a half-applied dnf
|
||||
# transaction is worse than a slightly stale deploy.
|
||||
concurrency:
|
||||
group: deploy
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
DEPLOY_KEY: |
|
||||
${{ secrets.RSYNC_SSH_KEY }}
|
||||
|
||||
jobs:
|
||||
deploy-cortex:
|
||||
runs-on: fedora-43
|
||||
# Two trigger paths: manual dispatch always runs; workflow_run
|
||||
# only runs if the upstream `build-prerelease` actually succeeded.
|
||||
if: >-
|
||||
${{
|
||||
github.event_name == 'workflow_dispatch'
|
||||
|| github.event.workflow_run.conclusion == 'success'
|
||||
}}
|
||||
steps:
|
||||
- name: SSH init
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
|
||||
chmod 600 ~/.ssh/id_ed25519
|
||||
ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
|
||||
gitea_ci@hanzalova.internal 'hostname -f'
|
||||
|
||||
# Gating compares `rpm -q` against the packages.json manifest the
|
||||
# publish job maintains — NOT unprivileged `dnf check-update`,
|
||||
# which proved unreliable as the gitea_ci user (hung on metadata
|
||||
# locks on one host, silently reported "no updates" on others).
|
||||
# An unreadable/unparsable manifest fails open: deploy proceeds.
|
||||
- name: Deploy cortex (skips when already current)
|
||||
run: |
|
||||
ssh gitea_ci@hanzalova.internal 'bash -s' <<'DEPLOY'
|
||||
set -eu
|
||||
pkg=cortex
|
||||
installed=$(rpm -q --qf '%{VERSION}-%{RELEASE}' "${pkg}" 2>/dev/null || echo "not-installed")
|
||||
latest=$(curl -fsS --max-time 15 "https://rpm.lair.cafe/fedora/43/x86_64/unstable/packages.json" 2>/dev/null \
|
||||
| python3 -c '
|
||||
import json, sys
|
||||
name = sys.argv[1]
|
||||
cands = [p for p in json.load(sys.stdin)["packages"] if p.get("name") == name]
|
||||
if cands:
|
||||
p = max(cands, key=lambda p: p.get("buildTime", 0))
|
||||
print(p["version"] + "-" + p["release"])
|
||||
' "${pkg}" 2>/dev/null || true)
|
||||
if [ -n "${latest}" ] && [ "${latest}" = "${installed}" ]; then
|
||||
echo "${pkg}-${installed} already current — leaving service untouched"
|
||||
exit 0
|
||||
fi
|
||||
echo "installed=${installed} published=${latest:-unknown} — deploying"
|
||||
if systemctl is-active --quiet cortex.service; then
|
||||
sudo /usr/bin/systemctl stop cortex.service
|
||||
fi
|
||||
if rpm -q "${pkg}" >/dev/null 2>&1; then
|
||||
sudo /usr/bin/dnf upgrade --refresh --allowerasing -y cortex
|
||||
else
|
||||
sudo /usr/bin/dnf install --refresh --allowerasing -y cortex
|
||||
fi
|
||||
sudo /usr/bin/systemctl daemon-reload
|
||||
sudo /usr/bin/systemctl start cortex.service
|
||||
DEPLOY
|
||||
|
||||
# Wait for the service to either come up or wedge, then capture
|
||||
# the latest-invocation journal. Runs even on prior failure so a
|
||||
# failed start step still leaves a usable record in the deploy log.
|
||||
- name: Capture cortex.service startup journal
|
||||
if: always()
|
||||
run: |
|
||||
sleep 10
|
||||
ssh gitea_ci@hanzalova.internal \
|
||||
'journalctl --unit cortex.service -I --no-pager'
|
||||
|
||||
deploy-neurons:
|
||||
needs: [deploy-cortex]
|
||||
runs-on: fedora-43
|
||||
strategy:
|
||||
# One neuron failing must not cancel the others. Cortex is up
|
||||
# already; a partial neuron deploy is strictly better than
|
||||
# rolling back to zero.
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
# load_timeout: how long to wait for default_models to finish
|
||||
# loading after a restart. beast cold-loads Qwen3.6-27B Q6K
|
||||
# TP=2 (~5-6 min typical, see #1); benjy/quadbrat load small
|
||||
# single-GPU models in well under a minute.
|
||||
- host: beast.hanzalova.internal
|
||||
flavour: blackwell
|
||||
load_timeout: 900
|
||||
- host: benjy.hanzalova.internal
|
||||
flavour: ada
|
||||
load_timeout: 300
|
||||
- host: quadbrat.hanzalova.internal
|
||||
flavour: ampere
|
||||
load_timeout: 300
|
||||
steps:
|
||||
- name: SSH init
|
||||
run: |
|
||||
mkdir -p ~/.ssh
|
||||
echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
|
||||
chmod 600 ~/.ssh/id_ed25519
|
||||
ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
|
||||
gitea_ci@${{ matrix.host }} 'hostname -f'
|
||||
|
||||
# See deploy-cortex for why gating uses the publish manifest and
|
||||
# not unprivileged `dnf check-update`.
|
||||
- name: Deploy helexa-neuron-${{ matrix.flavour }} (skips when already current)
|
||||
run: |
|
||||
ssh gitea_ci@${{ matrix.host }} 'bash -s' <<'DEPLOY'
|
||||
set -eu
|
||||
pkg=helexa-neuron-${{ matrix.flavour }}
|
||||
installed=$(rpm -q --qf '%{VERSION}-%{RELEASE}' "${pkg}" 2>/dev/null || echo "not-installed")
|
||||
latest=$(curl -fsS --max-time 15 "https://rpm.lair.cafe/fedora/43/x86_64/unstable/packages.json" 2>/dev/null \
|
||||
| python3 -c '
|
||||
import json, sys
|
||||
name = sys.argv[1]
|
||||
cands = [p for p in json.load(sys.stdin)["packages"] if p.get("name") == name]
|
||||
if cands:
|
||||
p = max(cands, key=lambda p: p.get("buildTime", 0))
|
||||
print(p["version"] + "-" + p["release"])
|
||||
' "${pkg}" 2>/dev/null || true)
|
||||
if [ -n "${latest}" ] && [ "${latest}" = "${installed}" ]; then
|
||||
echo "${pkg}-${installed} already current — leaving service untouched"
|
||||
exit 0
|
||||
fi
|
||||
echo "installed=${installed} published=${latest:-unknown} — deploying"
|
||||
if systemctl is-active --quiet neuron.service; then
|
||||
sudo /usr/bin/systemctl stop neuron.service
|
||||
fi
|
||||
if rpm -q "${pkg}" >/dev/null 2>&1; then
|
||||
sudo /usr/bin/dnf upgrade --refresh --allowerasing -y "${pkg}"
|
||||
else
|
||||
sudo /usr/bin/dnf install --refresh --allowerasing -y "${pkg}"
|
||||
fi
|
||||
sudo /usr/bin/systemctl daemon-reload
|
||||
sudo /usr/bin/systemctl start neuron.service
|
||||
|
||||
# ── Post-deploy validation ────────────────────────────────
|
||||
# A deploy only goes green if the neuron (a) finishes loading
|
||||
# its default models and (b) answers a trivial prompt like an
|
||||
# LLM should. Catches the class of bug where the binary
|
||||
# starts fine but model load or inference is broken — which
|
||||
# previously surfaced only when a human noticed. The wait
|
||||
# polls /health activation (the structured source of the
|
||||
# "loaded default model" journal line, plus per-model failure
|
||||
# detail); the journal-capture step below still runs for
|
||||
# forensics either way.
|
||||
load_timeout=${{ matrix.load_timeout }}
|
||||
echo "waiting for default models (timeout ${load_timeout}s)"
|
||||
deadline=$(( $(date +%s) + load_timeout ))
|
||||
health=""
|
||||
while :; do
|
||||
health=$(curl -fsS --max-time 5 http://localhost:13131/health 2>/dev/null || true)
|
||||
state=$(printf %s "${health}" | python3 -c '
|
||||
import json, sys
|
||||
try:
|
||||
print(json.load(sys.stdin).get("activation", {}).get("state", ""))
|
||||
except Exception:
|
||||
print("")
|
||||
')
|
||||
if [ "${state}" = "ready" ]; then
|
||||
break
|
||||
fi
|
||||
if [ "$(date +%s)" -ge "${deadline}" ]; then
|
||||
echo "FAIL: activation not ready within ${load_timeout}s (last state: ${state:-unreachable})"
|
||||
exit 1
|
||||
fi
|
||||
sleep 10
|
||||
done
|
||||
|
||||
model=$(printf %s "${health}" | python3 -c '
|
||||
import json, sys
|
||||
a = json.load(sys.stdin).get("activation", {})
|
||||
failed = a.get("failed", [])
|
||||
if failed:
|
||||
for f in failed:
|
||||
msg = "FAILED " + str(f.get("model_id")) + ": " + str(f.get("error", ""))[:400]
|
||||
sys.stderr.write(msg + chr(10))
|
||||
sys.exit(1)
|
||||
completed = a.get("completed", [])
|
||||
print(completed[0] if completed else "")
|
||||
')
|
||||
if [ -z "${model}" ]; then
|
||||
echo "no default models configured — skipping LLM probe"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "LLM probe against ${model}"
|
||||
probe_body=$(printf '{"model":"%s","messages":[{"role":"user","content":"Reply with exactly one word: pineapple"}],"max_tokens":512,"temperature":0}' "${model}")
|
||||
resp=$(curl -fsS --max-time 180 -H "content-type: application/json" \
|
||||
-d "${probe_body}" http://localhost:13131/v1/chat/completions) || {
|
||||
echo "FAIL: probe request errored"
|
||||
exit 1
|
||||
}
|
||||
if printf %s "${resp}" | grep -qi pineapple; then
|
||||
echo "LLM probe passed"
|
||||
else
|
||||
echo "FAIL: probe response missing expected token"
|
||||
printf %s "${resp}" | head -c 2000
|
||||
echo
|
||||
exit 1
|
||||
fi
|
||||
DEPLOY
|
||||
|
||||
- name: Ensure firewalld allows helexa-neuron
|
||||
run: |
|
||||
ssh gitea_ci@${{ matrix.host }} '
|
||||
if ! sudo /usr/bin/firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null; then
|
||||
sudo /usr/bin/firewall-cmd --add-service=helexa-neuron --permanent
|
||||
sudo /usr/bin/firewall-cmd --reload
|
||||
fi'
|
||||
|
||||
# Wait for the service to either come up or wedge, then capture
|
||||
# the latest-invocation journal. Runs even on prior failure so a
|
||||
# failed start step still leaves a usable record in the deploy log.
|
||||
- name: Capture neuron.service startup journal
|
||||
if: always()
|
||||
run: |
|
||||
sleep 10
|
||||
ssh gitea_ci@${{ matrix.host }} \
|
||||
'journalctl --unit neuron.service -I --no-pager'
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,7 +4,4 @@
|
||||
.idea/
|
||||
.vscode/
|
||||
cortex.toml
|
||||
models.toml
|
||||
doc/plan/*
|
||||
/target-cuda/
|
||||
.claude/
|
||||
|
||||
237
CLAUDE.md
237
CLAUDE.md
@@ -1,26 +1,16 @@
|
||||
# CLAUDE.md — helexa
|
||||
# CLAUDE.md — cortex
|
||||
|
||||
## Project overview
|
||||
|
||||
helexa is a self-hosted LLM serving stack for multi-node GPU inference
|
||||
clusters. It has two components:
|
||||
|
||||
- **cortex** — the per-operator control plane and LLM proxy. A Rust
|
||||
reverse-proxy that sits in front of the fleet and presents a unified
|
||||
OpenAI + Anthropic compatible API surface. It handles model routing,
|
||||
lifecycle management (load/unload/evict), request translation, and
|
||||
metrics collection.
|
||||
- **neuron** — the per-host LLM harness. One instance runs on every GPU
|
||||
host, serving candle-based in-process inference and managing local
|
||||
hardware discovery and model lifecycle.
|
||||
|
||||
(Historical note: cortex originally proxied to mistral.rs nodes; neuron
|
||||
replaced that — see the 2026-05-18 candle-native addendum below.)
|
||||
cortex is a Rust reverse-proxy that sits in front of multiple
|
||||
mistral.rs inference nodes and presents a unified OpenAI + Anthropic
|
||||
compatible API surface. It handles model routing, lifecycle management
|
||||
(load/unload/evict), request translation, and metrics collection.
|
||||
|
||||
## Repository layout
|
||||
|
||||
```
|
||||
helexa/
|
||||
cortex/
|
||||
├── Cargo.toml # workspace root
|
||||
├── cortex.toml # example gateway config
|
||||
├── README.md
|
||||
@@ -94,63 +84,6 @@ Per-request: model, node, prompt_tokens, completion_tokens, total_tokens,
|
||||
tok_per_sec, time_to_first_token_ms, total_latency_ms.
|
||||
Exposed as Prometheus histograms/counters on a separate port.
|
||||
|
||||
### Per-device worker thread (neuron)
|
||||
The neuron daemon dedicates one OS thread per CUDA device it loads
|
||||
onto. That thread binds the device's `CudaContext` once at startup and
|
||||
owns it for the daemon's lifetime; every model load, forward step,
|
||||
KV-cache reset, VRAM query, NCCL init/sanity, NCCL all_reduce, and
|
||||
model drop on that device routes through this thread via a
|
||||
`std::sync::mpsc` job channel. Replies cross back via
|
||||
`tokio::sync::oneshot`.
|
||||
|
||||
Three properties this gives us, in order of weight:
|
||||
|
||||
1. **Context locality.** cudarc binds the CUDA context per OS thread
|
||||
via `cuCtxSetCurrent`. Before this refactor, ad-hoc
|
||||
`tokio::task::spawn_blocking` calls bound the context onto a
|
||||
different thread per request — and `device_vram_mb()` from an
|
||||
async task bound it onto whichever tokio worker happened to be
|
||||
running. Pinning the context to one named thread ends that.
|
||||
2. **Drop safety.** Every `CudaSlice` in a `Tensor`, every
|
||||
`cudarc::nccl::Comm`, and the `CudaContext` itself call `cuMemFree` /
|
||||
`ncclCommDestroy` / `cuCtxDestroy` during `Drop` — and require the
|
||||
right context current. With the worker owning the model slab,
|
||||
`Drop` always runs on the right thread. The cudarc Drop constraint
|
||||
is structurally enforced.
|
||||
3. **Poisoning blast radius.** When a CUDA driver error makes the
|
||||
context unrecoverable, the poison flag lives on the
|
||||
`DeviceWorkerHandle` itself. Subsequent `submit()` calls fast-reject
|
||||
at the channel boundary with a clear "device worker is poisoned"
|
||||
error before any further CUDA work is attempted. The thread doesn't
|
||||
exit (dropping the slab would re-touch the broken context) — it
|
||||
enters a drain-only mode and replies error to everything until the
|
||||
daemon restarts.
|
||||
|
||||
Tensors never escape the worker thread alive. Inference replies carry
|
||||
`Vec<f32>` CPU-side logits; the async caller wraps them in a CPU
|
||||
candle tensor and runs `apply_repeat_penalty` + `LogitsProcessor::sample`
|
||||
without ever rebinding the device context. Sampled tokens come back as
|
||||
`u32`; VRAM queries as `(u64, u64)`. The opaque `ArchHandle(u64)` and
|
||||
`TpHandle(u64)` are the only "references" callers hold to loaded
|
||||
models — they're indices into the worker's state slab, not pointers.
|
||||
|
||||
The TP worker subprocesses in `harness/tp/worker.rs` are the same
|
||||
pattern out-of-process — a dedicated context-owning process per
|
||||
non-zero NCCL rank. The in-process worker in `harness/device_worker/`
|
||||
brings the discipline to rank 0.
|
||||
|
||||
CPU loads (`Device::Cpu` fallback when CUDA is unavailable) keep the
|
||||
legacy `tokio::task::spawn_blocking + Arc<Mutex<ModelArch>>` path —
|
||||
there's no context to own and the channel hop would only add latency.
|
||||
Four `spawn_blocking` references in `harness/candle.rs` are deliberate
|
||||
CPU fallback.
|
||||
|
||||
Canonical narrative lives in
|
||||
`crates/neuron/src/harness/device_worker/mod.rs`'s module
|
||||
doc-comment; touch points (the `Job` enum, the dispatch handlers, the
|
||||
`DeviceWorkerState` struct) are in the sibling `jobs.rs` and
|
||||
`dispatch.rs`.
|
||||
|
||||
## Tech stack
|
||||
|
||||
- **Rust 2024 edition** — workspace with 4 crates
|
||||
@@ -558,7 +491,7 @@ and the hardcoded `vram_mb` per node.
|
||||
## Revised repository layout
|
||||
|
||||
```
|
||||
helexa/
|
||||
cortex/
|
||||
├── Cargo.toml
|
||||
├── cortex.toml # gateway config (neurons only)
|
||||
├── models.toml # model catalogue
|
||||
@@ -683,120 +616,58 @@ dnf install cortex # gateway host
|
||||
dnf install helexa-neuron # GPU nodes
|
||||
```
|
||||
|
||||
## 2026-05-18 addendum: candle-native pivot
|
||||
### Phase 11: llama.cpp harness stub
|
||||
|
||||
Phases 11 (llama.cpp harness) and 12 (mistral.rs COPR) below are
|
||||
**superseded**. The project no longer treats mistral.rs or llama.cpp as
|
||||
dependencies — both are conceptually out of scope. neuron becomes a
|
||||
candle-native inference daemon, with `Harness` retained as an
|
||||
internal seam for adding future engines (vision/audio/diffusion) but
|
||||
its only implementation being in-process candle.
|
||||
**Goal:** Prove the harness abstraction works with a second engine.
|
||||
|
||||
The full staged plan for this pivot lives at
|
||||
`~/.claude/plans/create-a-more-aggressive-calm-naur.md`. Summary:
|
||||
**Steps:**
|
||||
1. `crates/neuron/src/harness/llamacpp.rs` — implement the `Harness`
|
||||
trait for llama.cpp's `llama-server`.
|
||||
- `start()` — launch `llama-server` with the correct model path,
|
||||
`--port`, `--n-gpu-layers`, `--tensor-split` args. Track the
|
||||
child process.
|
||||
- `stop()` — send SIGTERM to the child process.
|
||||
- `list_models()` — llama-server serves one model per process, so
|
||||
return a single-element list.
|
||||
- `load_model()` — start a new llama-server process for this model.
|
||||
- `unload_model()` — stop the process.
|
||||
- `inference_endpoint()` — return `http://localhost:{assigned_port}`.
|
||||
2. Port allocation: neuron assigns ports from a range (e.g. 8100-8199)
|
||||
to llama-server instances.
|
||||
3. Register in `HarnessRegistry` when configured:
|
||||
```toml
|
||||
[[harnesses]]
|
||||
name = "llamacpp"
|
||||
binary = "/usr/local/bin/llama-server"
|
||||
port_range = [8100, 8199]
|
||||
```
|
||||
4. Tests: mock llama-server (simple HTTP server returning canned
|
||||
responses), test load/unload/endpoint lifecycle.
|
||||
|
||||
- **Stage 1 (this commit):** delete `mistralrs.rs` and `llamacpp.rs`,
|
||||
scaffold inert `CandleHarness`, drop `endpoint`/`systemd_unit` from
|
||||
`HarnessConfig`, default no-op `start`/`stop` on the `Harness` trait.
|
||||
- **Stages 2–4:** wire up candle model load/unload (quantized Qwen3
|
||||
first), add OpenAI-compatible inference endpoint in neuron, then SSE
|
||||
streaming.
|
||||
- **Stages 5–6:** load-on-activation (default models in config) and
|
||||
unload-on-deactivation (graceful shutdown).
|
||||
- **Stages 7–8:** multi-GPU tensor parallelism and broader model/quant
|
||||
coverage.
|
||||
**Done when:** A model with `harness = "llamacpp"` in `models.toml` can
|
||||
be loaded and served through cortex. Tests pass with mock llama-server.
|
||||
|
||||
Sections of this document that describe mistral.rs HTTP behaviour
|
||||
("mistral.rs API gotchas") are retained as historical context for
|
||||
Phases 1–10 — they document what was true while the project depended
|
||||
on mistral.rs. They do not describe current behaviour.
|
||||
### Phase 12 (lower priority): mistral.rs COPR packaging
|
||||
|
||||
---
|
||||
**Goal:** Fedora RPMs for mistral.rs built against specific CUDA versions.
|
||||
|
||||
### Phase 11 (superseded): llama.cpp harness stub
|
||||
**Steps:**
|
||||
1. `mistralrs-cuda.spec` — RPM spec that clones a pinned mistral.rs git
|
||||
tag, builds with `--features cuda`, links against the system CUDA
|
||||
toolkit. Produces `mistralrs-cuda13-server` (CUDA 13.x / sm_120) and
|
||||
`mistralrs-cuda12-server` (CUDA 12.x / sm_89). Install binary to
|
||||
`/usr/local/bin/mistralrs`.
|
||||
2. COPR build config: enable the NVIDIA CUDA repo as a build dependency.
|
||||
Pin the CUDA toolkit version in `BuildRequires`.
|
||||
3. Gitea Actions or manual workflow: bump the mistral.rs tag in the spec,
|
||||
trigger COPR rebuild.
|
||||
4. neuron's mistralrs harness config references which binary/package
|
||||
provides the mistral.rs binary. neuron could warn at startup if the
|
||||
installed mistral.rs CUDA version doesn't match the discovered driver.
|
||||
|
||||
~~Originally planned as a second engine to prove the harness
|
||||
abstraction.~~ Replaced by the candle harness work in the 2026-05-18
|
||||
addendum above. llama.cpp's any-model/any-hardware breadth is no
|
||||
longer in scope for helexa.
|
||||
**Done when:** `dnf install mistralrs-cuda13-server` on beast provides a
|
||||
working `mistralrs` binary built for Blackwell GPUs. `dnf install
|
||||
mistralrs-cuda12-server` on benjy provides one built for Ada GPUs.
|
||||
|
||||
### Phase 12 (superseded): mistral.rs COPR packaging
|
||||
|
||||
~~Originally planned to ship CUDA-versioned mistral.rs RPMs.~~ Replaced
|
||||
by the candle harness work in the 2026-05-18 addendum above. With
|
||||
mistral.rs out of the dependency tree, there is nothing to package.
|
||||
|
||||
## 2026-05-27 addendum: per-device worker thread
|
||||
|
||||
Replaced the ad-hoc `tokio::task::spawn_blocking` pattern that drove
|
||||
every leader-side CUDA op with one dedicated OS thread per CUDA device,
|
||||
permanently bound to that device's `CudaContext`. All leader-side
|
||||
inference work (GGUF + dense + TP shard load, forward, kv-cache clear,
|
||||
NCCL init/sanity, NCCL all_reduce, VRAM query, model drop) routes
|
||||
through the worker via a `std::sync::mpsc` channel; tensors never
|
||||
escape the worker thread alive. See "Per-device worker thread (neuron)"
|
||||
above and `crates/neuron/src/harness/device_worker/mod.rs` for the
|
||||
canonical narrative.
|
||||
|
||||
Motivated by the 2026-05-26 silent-hang on beast: a CUDA OOM cascade
|
||||
poisoned the device context on whichever spawn_blocking thread caught
|
||||
it, and subsequent requests stalled invisibly on the pool lock. After
|
||||
the refactor, the same failure mode shows up in journalctl as
|
||||
`prefill sample failed; logits unhealthy nan: 248320/248320` followed
|
||||
by `failed, model marked poisoned`. The thread stays alive and rejects
|
||||
subsequent requests at the channel boundary.
|
||||
|
||||
Landed in four PRs:
|
||||
|
||||
- **Phase 1** (`081b532`) — device_worker module + 8 VRAM-query sites
|
||||
route through the worker. CPU build only; smoke on beast confirmed
|
||||
a persistent `cuda-dev-0` thread.
|
||||
- **Phase 2** (`b179204`) — single-GPU forward + clear_kv + drop via
|
||||
the worker. `LoadedModel.arch_handle: Option<ArchHandle>` replaces
|
||||
`Arc<Mutex<ModelArch>>` for CUDA loads. CPU keeps the legacy path.
|
||||
- **Phase 3** (`76ab24d`) — TP forward + NCCL init/sanity + leader
|
||||
KV-clear routed through the worker. `WorkerPool.leader_nccl` moves
|
||||
into the worker's state. `TpLoadedModel.leader_handle: TpHandle`
|
||||
replaces `Arc<Mutex<TpLeaderModel>>`. CUDA-only TP smoke deferred to
|
||||
next deploy.
|
||||
- **Phase 4** (`b4f3576`) — GGUF + dense + TP shard loads move onto
|
||||
the worker. The `Job::TransferIn` / `Job::CloneLeaderComm` bridges
|
||||
from Phases 2/3 deleted; `SendComm` newtype no longer needed in the
|
||||
load path. `grep -rn spawn_blocking crates/neuron/src/harness/`
|
||||
returns only deliberate CPU-fallback hits after this PR.
|
||||
|
||||
## 2026-06-13 addendum: build metadata + helexa-bench
|
||||
|
||||
Two coupled additions so fleet performance can be tracked automatically
|
||||
across neuron updates instead of by hand-running `script/bench.py` and
|
||||
editing `doc/benchmarks.md`.
|
||||
|
||||
**neuron build metadata + `GET /version`.** neuron's `build.rs` now also
|
||||
captures build identity (`HELEXA_GIT_SHA` — preferring a CI/RPM-injected
|
||||
`HELEXA_BUILD_SHA`, falling back to git, else `unknown` — plus dirty
|
||||
flag, build timestamp, rustc version, profile, enabled cargo features,
|
||||
and a best-effort `candle-core` version from `Cargo.lock`). These are
|
||||
exposed as `cortex_core::build_info::BuildInfo` (new module) from a new
|
||||
`GET /version` endpoint (`neuron/src/version.rs`, wired in `api.rs`) and
|
||||
in clap's `--version` long form. The SHA is injected in CI
|
||||
(`build-prerelease.yml` build-neuron step: `export HELEXA_BUILD_SHA=$(git
|
||||
rev-parse HEAD)`) and via `--define helexa_commit` in the source-build
|
||||
spec, so tarball-built RPMs report the real SHA. `/version` is now the
|
||||
canonical "which build is live" probe (supersedes the per-host RPM-sha
|
||||
check in the fleet-validation flow).
|
||||
|
||||
**`crates/helexa-bench`** — a new binary: a continuous, version-aware
|
||||
benchmark harness (one systemd unit, typically on the metrics host). It
|
||||
hits each neuron **directly** on `:13131`, exercises each **warm**
|
||||
(`status == "loaded"`) model with an extensible `Scenario` suite (phase
|
||||
1: the chat-latency family ported verbatim from `bench.py` — synthetic
|
||||
128/4096-tok prompts, `/no_think`, streamed TTFT + decode-window
|
||||
tok/s), and records each run into a SQLite system-of-record stamped with
|
||||
the neuron's full `BuildInfo`. The loop is **version-aware**: it skips
|
||||
any (target, build SHA, model, scenario) cell already at
|
||||
`samples_per_version`, so a steady fleet costs only cheap `/version` +
|
||||
`/models` polls until a new SHA ships. `helexa-bench report` regenerates
|
||||
the `benchmarks.md`-style table from the DB. `kind = "openai"` targets
|
||||
(mistral.rs/llama.cpp comparison) are scaffolded but not yet wired.
|
||||
Packaged as the `helexa-bench` RPM (prebuilt-binary spec, outbound-only
|
||||
so no firewalld service) via the same `build-prerelease.yml` pipeline.
|
||||
It has its own RPM spec but is a workspace member (`crates/helexa-bench`),
|
||||
tightly coupled operationally. Track it alongside cortex and neuron.
|
||||
|
||||
2626
Cargo.lock
generated
2626
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
17
Cargo.toml
17
Cargo.toml
@@ -5,15 +5,13 @@ members = [
|
||||
"crates/cortex-gateway",
|
||||
"crates/cortex-cli",
|
||||
"crates/neuron",
|
||||
"crates/helexa-acp",
|
||||
"crates/helexa-bench",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.16"
|
||||
version = "0.1.12"
|
||||
edition = "2024"
|
||||
license = "GPL-3.0-or-later"
|
||||
repository = "https://git.lair.cafe/helexa/helexa"
|
||||
repository = "https://git.lair.cafe/helexa/cortex"
|
||||
|
||||
[workspace.dependencies]
|
||||
# async runtime
|
||||
@@ -29,7 +27,7 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
|
||||
# http client (for proxying to neuron backends)
|
||||
# http client (for proxying to mistralrs backends)
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
|
||||
# observability
|
||||
@@ -62,12 +60,3 @@ eventsource-stream = "0.2"
|
||||
# workspace crates
|
||||
cortex-core = { path = "crates/cortex-core" }
|
||||
cortex-gateway = { path = "crates/cortex-gateway" }
|
||||
|
||||
# Patched cudarc (affects neuron's 0.19.x only; candle's 0.17.x is
|
||||
# untouched since the fork is 0.19.7 and doesn't satisfy a 0.17 req). Adds
|
||||
# Comm::abort / get_async_error / raw comm() — needed for #17 Stage 2 TP
|
||||
# hang-recovery (abort a wedged collective from another thread, then
|
||||
# rebuild the comm). Pinned to a fork revision pending upstream review
|
||||
# (grenade/cudarc @ nccl-comm-abort).
|
||||
[patch.crates-io]
|
||||
cudarc = { git = "https://github.com/grenade/cudarc", rev = "63327a256059f8252641ae46c6bb9eefe707f382" }
|
||||
|
||||
223
README.md
223
README.md
@@ -1,68 +1,24 @@
|
||||
# helexa
|
||||
# cortex
|
||||
|
||||
**Near-frontier AI for mortals.**
|
||||
A Rust reverse-proxy and fleet management layer for multi-node
|
||||
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) inference clusters.
|
||||
|
||||
helexa is a self-hosted LLM serving stack, written in Rust, for people
|
||||
who run open-weight models on their own consumer GPUs. It has two
|
||||
components:
|
||||
## Problem
|
||||
|
||||
- **cortex** — the per-operator control plane and LLM proxy. It sits in
|
||||
front of your GPU fleet and presents a unified OpenAI + Anthropic
|
||||
compatible API surface, handling model routing, lifecycle management
|
||||
(load / unload / evict), request translation, and metrics.
|
||||
- **neuron** — the per-host LLM harness. One instance runs on every GPU
|
||||
host, serving candle-based in-process inference and managing local
|
||||
hardware discovery and model lifecycle.
|
||||
Running local LLMs across multiple GPU nodes (different VRAM tiers, different
|
||||
model affinities) requires a unified API surface that:
|
||||
|
||||
## Why
|
||||
|
||||
Two principles constrain everything in this repository:
|
||||
|
||||
1. **Frontier or close to it.** helexa serves the open-weight models
|
||||
that get nearest to frontier capability — not every architecture
|
||||
ever published.
|
||||
2. **Consumer hardware.** Everything must run on the cards mortals can
|
||||
actually buy: a 3060 here, a 4090 there, a 5090 if you got lucky.
|
||||
Mixed VRAM tiers across mismatched boxes are the expected topology,
|
||||
not a degraded case.
|
||||
|
||||
GPU acquisition is harder than it was a year ago, and the gap between
|
||||
what cloud providers charge and what your own silicon costs keeps
|
||||
widening. The intersection of those two principles — near-frontier
|
||||
models, squeezed onto hardware you own — is helexa's entire niche.
|
||||
|
||||
The secondary objective is **predictable consumption**. If you own the
|
||||
hardware, your tooling shouldn't break because a cloud provider changed
|
||||
billing, deprecated a model, or reshaped an API. cortex's OpenAI and
|
||||
Anthropic surfaces are a stability contract: point your editor, agent,
|
||||
or CLI at it once, and it keeps working.
|
||||
|
||||
## What helexa is not
|
||||
|
||||
This is an intentionally different path from vLLM, SGLang, and peers —
|
||||
not a smaller version of them. Out of scope, permanently:
|
||||
|
||||
- Any-model breadth. Architectures are ported because they're at or
|
||||
near the frontier, not to complete a compatibility matrix.
|
||||
- Datacenter-class scheduling. No sophisticated continuous-batching /
|
||||
paged-attention machinery — the workload is a handful of operators
|
||||
and their agents, not 200 QPS.
|
||||
- Wrapping external inference engines. neuron builds directly on
|
||||
[candle](https://github.com/huggingface/candle); every model
|
||||
architecture it serves is implemented in this repository, ported
|
||||
against the HuggingFace reference.
|
||||
|
||||
One thing that is *not* a principle: CUDA exclusivity. All high-end
|
||||
consumer hardware is in scope. helexa is CUDA-only today because
|
||||
that's the hardware on the bench — nothing ships untested — and ROCm
|
||||
or other consumer accelerators join as soon as there's real hardware
|
||||
to build against.
|
||||
|
||||
In scope, and where the engineering effort goes: aggressive
|
||||
quantization (GGUF Q4_K_M / Q6_K / Q8_0), NCCL tensor parallelism
|
||||
across heterogeneous consumer GPUs, careful CUDA failure handling, and
|
||||
single-request latency — the performance that one operator at a
|
||||
keyboard actually feels.
|
||||
- Presents a **single `/v1/models` catalogue** merging every model across every
|
||||
node.
|
||||
- **Routes requests** to the correct node based on where a model is loaded (or
|
||||
*can* be loaded).
|
||||
- Manages **model lifecycle** — unload cold models, reload on demand, pin
|
||||
critical ones — using the mistral.rs
|
||||
`/v1/models/{unload,reload,status}` HTTP API (PR #1828+).
|
||||
- Translates between **OpenAI and Anthropic** request/response envelopes so
|
||||
every client in the homelab speaks whichever dialect it prefers.
|
||||
- Captures **per-request metrics** (tokens, tok/s, TTFT, latency) and exposes
|
||||
them as Prometheus counters/histograms.
|
||||
|
||||
## Architecture
|
||||
|
||||
@@ -72,79 +28,65 @@ keyboard actually feels.
|
||||
└──────┬───────┘ └─────┬────┘ └──────┬─────┘ └──────┬─────┘
|
||||
│ │ │ │
|
||||
└────────────────┴──────┬───────┴───────────────┘
|
||||
│ OpenAI + Anthropic APIs
|
||||
│
|
||||
┌──────────▼──────────┐
|
||||
│ cortex │
|
||||
│ (cortex-gateway) │
|
||||
│ cortex │
|
||||
│ (cortex-gateway) │
|
||||
│ │
|
||||
│ Router · Metrics │
|
||||
│ Evictor · Translate│
|
||||
└──┬──────┬────────┬──┘
|
||||
│ │ │
|
||||
┌──────────▼┐ ┌──▼─────┐ ┌▼──────────┐
|
||||
│ neuron │ │ neuron │ │ neuron │
|
||||
│ :13131 │ │ :13131 │ │ :13131 │
|
||||
│ candle │ │ candle │ │ candle │
|
||||
│ gpu-large │ │gpu-med │ │ gpu-small │
|
||||
│ mistralrs │ │mistral │ │ mistralrs │
|
||||
│ serve │ │rs serve│ │ serve │
|
||||
│ :8080 │ │ :8080 │ │ :8080 │
|
||||
└───────────┘ └────────┘ └───────────┘
|
||||
private network (.internal)
|
||||
```
|
||||
|
||||
cortex discovers each neuron's hardware (devices, VRAM, compute
|
||||
capability) at runtime and matches it against a model catalogue
|
||||
(`models.toml`) to decide placement: which models fit where, what to
|
||||
evict when VRAM is tight, where to route a request right now. Adding a
|
||||
GPU host to the fleet is one `[[neurons]]` entry — no device specs in
|
||||
config.
|
||||
|
||||
### Crates
|
||||
|
||||
| Crate | Purpose |
|
||||
|---|---|
|
||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic envelopes, harness trait, discovery types |
|
||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, poller, metrics exporter |
|
||||
| `neuron` | Per-host daemon: GPU discovery, in-process candle inference, NCCL tensor parallelism, model lifecycle API |
|
||||
| `cortex-core` | Shared types: config, node/model state, metrics, OpenAI/Anthropic request/response envelopes |
|
||||
| `cortex-gateway` | Axum HTTP server: proxy, router, evictor, metrics exporter |
|
||||
| `cortex-agent` | Per-node sidecar: polls local mistralrs, reports to gateway, handles restart/defrag |
|
||||
| `cortex-cli` | CLI entrypoint (`cortex serve`, `cortex status`, etc.) |
|
||||
| `helexa-acp` | Agent Client Protocol bridge — connects ACP editors (Zed, etc.) to any OpenAI-compatible endpoint, cortex by default |
|
||||
|
||||
## The engine
|
||||
## Node setup
|
||||
|
||||
neuron runs inference in-process on candle — there is no external
|
||||
inference server to babysit. The parts that earn their keep:
|
||||
Each GPU node runs `mistralrs serve` with a multi-model config. Models are
|
||||
declared but start **unloaded** — mistral.rs lazy-loads on first request and
|
||||
the gateway can explicitly unload/reload via the HTTP API.
|
||||
|
||||
- **Per-device worker threads.** Every CUDA device gets one dedicated
|
||||
OS thread that owns its CUDA context for the daemon's lifetime. All
|
||||
loads, forward passes, KV-cache resets, NCCL collectives, VRAM
|
||||
queries, and unloads route through it; tensors never escape it
|
||||
alive. Context binding is pinned to a known thread, the CUDA `Drop`
|
||||
contract is structurally safe, and a driver error poisons one worker
|
||||
— visibly — instead of hanging the whole process.
|
||||
- **Tensor parallelism on consumer cards.** Megatron-style row/column
|
||||
parallel layers with NCCL all-reduce, spanning the mismatched GPUs
|
||||
you actually have. A step watchdog aborts wedged collectives instead
|
||||
of letting a request hang forever.
|
||||
- **Current model focus: the Qwen3 family** — dense and GGUF-quantized,
|
||||
including the hybrid linear-attention (Gated DeltaNet) generation.
|
||||
Vision support is in progress. Each architecture is ported against
|
||||
its HuggingFace reference implementation.
|
||||
Example node systemd unit:
|
||||
|
||||
See `CLAUDE.md` for design rationale and
|
||||
`crates/neuron/src/harness/device_worker/` for the worker narrative.
|
||||
```ini
|
||||
# /etc/systemd/system/mistralrs.service
|
||||
[Unit]
|
||||
Description=mistral.rs inference server
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
## Install
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=/usr/local/bin/mistralrs serve \
|
||||
--from-config /etc/mistralrs/config.toml \
|
||||
--port 8080
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
Environment=CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
Pre-built RPMs for Fedora:
|
||||
|
||||
```sh
|
||||
dnf copr enable helexa/helexa
|
||||
dnf install cortex # on the gateway host
|
||||
dnf install helexa-neuron # on each GPU host
|
||||
systemctl enable --now cortex # or neuron, respectively
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
## Configure
|
||||
## Gateway config
|
||||
|
||||
```toml
|
||||
# /etc/cortex/cortex.toml
|
||||
# cortex.toml
|
||||
[gateway]
|
||||
listen = "0.0.0.0:31313"
|
||||
metrics_listen = "0.0.0.0:31314"
|
||||
@@ -153,38 +95,35 @@ metrics_listen = "0.0.0.0:31314"
|
||||
strategy = "lru" # lru | priority
|
||||
defrag_after_cycles = 50
|
||||
|
||||
[[neurons]]
|
||||
name = "beast"
|
||||
endpoint = "http://beast.internal:13131"
|
||||
[[nodes]]
|
||||
name = "gpu-large"
|
||||
endpoint = "http://gpu-large.internal:8080"
|
||||
vram_mb = 49_152 # e.g. 2x RTX 4090
|
||||
pinned = ["your-org/large-model"]
|
||||
|
||||
[[neurons]]
|
||||
name = "benjy"
|
||||
endpoint = "http://benjy.internal:13131"
|
||||
[[nodes]]
|
||||
name = "gpu-medium"
|
||||
endpoint = "http://gpu-medium.internal:8080"
|
||||
vram_mb = 24_576 # e.g. RTX 4090
|
||||
pinned = ["your-org/medium-model"]
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-small"
|
||||
endpoint = "http://gpu-small.internal:8080"
|
||||
vram_mb = 12_288 # e.g. RTX 3060
|
||||
pinned = ["your-org/embedding-model"]
|
||||
```
|
||||
|
||||
Model placement profiles (VRAM requirements, quant, device minimums,
|
||||
pinning) live in `models.toml` — see `models.example.toml`.
|
||||
|
||||
## Run
|
||||
|
||||
```sh
|
||||
# start the gateway
|
||||
cortex serve --config /etc/cortex/cortex.toml
|
||||
|
||||
# check fleet status
|
||||
cortex status
|
||||
|
||||
# one catalogue across every node
|
||||
curl http://localhost:31313/v1/models
|
||||
```
|
||||
|
||||
## Build from source
|
||||
## Building
|
||||
|
||||
```sh
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
CI runs on every push; keep it green locally:
|
||||
## CI
|
||||
|
||||
Every push triggers format, lint, and test checks. Ensure these pass
|
||||
locally before pushing:
|
||||
|
||||
```sh
|
||||
cargo fmt --check --all # must be clean
|
||||
@@ -192,18 +131,20 @@ cargo clippy --workspace -- -D warnings # warnings are errors
|
||||
cargo test --workspace # all tests must pass
|
||||
```
|
||||
|
||||
Tagged releases (`v*`) build SRPMs for `cortex` and `helexa-neuron`
|
||||
and publish to COPR.
|
||||
Tagged releases (`v*`) additionally build an SRPM and publish to COPR.
|
||||
|
||||
## Status
|
||||
## Running
|
||||
|
||||
Pre-1.0 and moving fast. The gateway path (routing, eviction,
|
||||
translation, metrics) is stable and tested; the candle-native engine
|
||||
is under active development — expect the supported-model list to track
|
||||
the open-weight frontier, deliberately narrowly.
|
||||
```sh
|
||||
# start the gateway
|
||||
cortex serve --config cortex.toml
|
||||
|
||||
Development happens at <https://git.lair.cafe/helexa/helexa>;
|
||||
<https://github.com/helexa-ai/helexa> is a read-only mirror.
|
||||
# check fleet status
|
||||
cortex status
|
||||
|
||||
# list all models across nodes
|
||||
curl http://localhost:31313/v1/models
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
# neuron.toml for beast.hanzalova.internal
|
||||
#
|
||||
# 2x RTX 5090 (32 GB each) — TP-2 capable. Pre-warms Qwen3.6-27B with
|
||||
# q6k ISQ across both GPUs at activation, matching the validate-neuron
|
||||
# invocation: `validate-neuron.sh beast.hanzalova.internal
|
||||
# Qwen/Qwen3.6-27B q6k 2`.
|
||||
#
|
||||
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh. Edits
|
||||
# take effect after the next deploy workflow run restarts the service
|
||||
# (default_models is read at activation).
|
||||
|
||||
port = 13131
|
||||
|
||||
[[harnesses]]
|
||||
name = "candle"
|
||||
|
||||
[harness.candle]
|
||||
|
||||
[[default_models]]
|
||||
model_id = "Qwen/Qwen3.6-27B"
|
||||
harness = "candle"
|
||||
quant = "q6k"
|
||||
tensor_parallel = 2
|
||||
devices = [0, 1]
|
||||
@@ -1,19 +0,0 @@
|
||||
# neuron.toml for benjy.hanzalova.internal
|
||||
#
|
||||
# 1x RTX 4090 (24 GB) — largest single-GPU host on the fleet. Pre-warms
|
||||
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
|
||||
# moderate-length contexts.
|
||||
#
|
||||
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh.
|
||||
|
||||
port = 13131
|
||||
|
||||
[[harnesses]]
|
||||
name = "candle"
|
||||
|
||||
[harness.candle]
|
||||
|
||||
[[default_models]]
|
||||
model_id = "Qwen/Qwen3-8B"
|
||||
harness = "candle"
|
||||
devices = [0]
|
||||
@@ -1,19 +0,0 @@
|
||||
# neuron.toml for quadbrat.hanzalova.internal
|
||||
#
|
||||
# 1x RTX 3060 (12 GB) — small / quantised tier. Pre-warms Qwen3-1.7B
|
||||
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
|
||||
# model still have plenty of room.
|
||||
#
|
||||
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh.
|
||||
|
||||
port = 13131
|
||||
|
||||
[[harnesses]]
|
||||
name = "candle"
|
||||
|
||||
[harness.candle]
|
||||
|
||||
[[default_models]]
|
||||
model_id = "Qwen/Qwen3-1.7B"
|
||||
harness = "candle"
|
||||
devices = [0]
|
||||
@@ -1,20 +0,0 @@
|
||||
# Install on the cortex gateway host as /etc/sudoers.d/helexa_gitea_ci
|
||||
# (owner root:root, mode 0440). Required by .gitea/workflows/deploy.yml,
|
||||
# which SSHes as gitea_ci@<gateway> to roll out cortex package upgrades
|
||||
# and config changes.
|
||||
#
|
||||
# Filename convention `helexa_gitea_ci` (vs bare `gitea_ci`) so other
|
||||
# helexa-org apps can drop their own sudoers files on the same host
|
||||
# without overwriting this one.
|
||||
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/cortex/cortex.toml
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/cortex/models.toml
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl start cortex.service
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl stop cortex.service
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl daemon-reload
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y cortex
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y cortex
|
||||
# sudoers reserves `:` and `=` and requires `\` escaping inside command
|
||||
# arguments — without it visudo errors at the first `:` in `https://`.
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://rpm.lair.cafe/lair-cafe-unstable.repo
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager setopt lair-cafe-unstable.enabled\=1
|
||||
@@ -1,38 +0,0 @@
|
||||
# Install on every neuron host as /etc/sudoers.d/helexa_gitea_ci
|
||||
# (owner root:root, mode 0440). Required by .gitea/workflows/deploy.yml,
|
||||
# which SSHes as gitea_ci@<neuron-host> to roll out helexa-neuron-<flavour>
|
||||
# package upgrades and config changes.
|
||||
#
|
||||
# Filename convention `helexa_gitea_ci` (vs bare `gitea_ci`) so other
|
||||
# helexa-org apps can drop their own sudoers files on the same host
|
||||
# without overwriting this one.
|
||||
#
|
||||
# All three CUDA flavours are listed because a host's flavour can change
|
||||
# (e.g. GPU swap) and we don't want the sudoers file to need to change
|
||||
# in lockstep. Only one flavour can be installed at a time (the packages
|
||||
# Conflict: with each other), so the attack surface is bounded to "wrong
|
||||
# flavour installed" — vandalism, not privilege escalation.
|
||||
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/neuron/neuron.toml
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl start neuron.service
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl stop neuron.service
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl daemon-reload
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-ampere
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-ampere
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-ada
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-ada
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-blackwell
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-blackwell
|
||||
# sudoers reserves `:` and `=` and requires `\` escaping inside command
|
||||
# arguments — without it visudo errors at the first `:` in `https://`.
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://rpm.lair.cafe/lair-cafe-unstable.repo
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager setopt lair-cafe-unstable.enabled\=1
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install -y libcudnn9-cuda-13
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/firewall-cmd --add-service=helexa-neuron --permanent
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/firewall-cmd --reload
|
||||
# deploy-dev.yml fast path: install a freshly-built dev binary over the
|
||||
# packaged one. Exact source path + args; the workflow must use this
|
||||
# command form verbatim. The next deploy.yml run reconciles the host
|
||||
# back to the RPM-owned binary.
|
||||
gitea_ci ALL=(root) NOPASSWD: /usr/bin/install -o root -g root -m 0755 /var/lib/gitea_ci/neuron-dev /usr/bin/neuron
|
||||
@@ -11,14 +11,14 @@ metrics_listen = "0.0.0.0:31314"
|
||||
|
||||
[eviction]
|
||||
strategy = "lru"
|
||||
# Restart neurons after this many load/unload cycles to defragment VRAM.
|
||||
# Restart mistralrs after this many load/unload cycles to defragment VRAM.
|
||||
# Set to 0 to disable.
|
||||
defrag_after_cycles = 50
|
||||
|
||||
# -- Nodes ---------------------------------------------------------------
|
||||
# Each [[nodes]] entry declares a neuron daemon in the fleet.
|
||||
# Models are discovered by polling the neuron's /models endpoint.
|
||||
# Pinned models (see models.toml) are never evicted.
|
||||
# Each [[nodes]] entry declares a mistral.rs instance in the fleet.
|
||||
# Models are discovered by polling the node's /v1/models endpoint.
|
||||
# Pinned models are never evicted.
|
||||
|
||||
[[nodes]]
|
||||
name = "gpu-large"
|
||||
|
||||
43
cortex.spec
43
cortex.spec
@@ -1,10 +1,10 @@
|
||||
Name: cortex
|
||||
Version: 0.1.16
|
||||
Version: 0.1.12
|
||||
Release: 1%{?dist}
|
||||
Summary: Inference gateway for multi-node GPU clusters
|
||||
|
||||
License: GPL-3.0-or-later
|
||||
URL: https://git.lair.cafe/helexa/helexa
|
||||
URL: https://git.lair.cafe/helexa/cortex
|
||||
Source0: %{name}-%{version}.tar.gz
|
||||
Source1: %{name}-%{version}-vendor.tar.gz
|
||||
|
||||
@@ -21,7 +21,6 @@ BuildRequires: systemd-rpm-macros
|
||||
|
||||
Requires(pre): shadow-utils
|
||||
Requires: systemd
|
||||
Requires: firewalld-filesystem
|
||||
|
||||
# systemd-rpm-macros ships a unit dep generator that parses User=/Group=
|
||||
# from our .service file and emits Requires: user(cortex)/group(cortex).
|
||||
@@ -57,7 +56,6 @@ cargo build --release -p cortex-cli
|
||||
install -Dm755 target/release/cortex %{buildroot}%{_bindir}/cortex
|
||||
install -Dm644 data/cortex.service %{buildroot}%{_unitdir}/cortex.service
|
||||
install -Dm644 data/cortex-sysusers.conf %{buildroot}%{_sysusersdir}/cortex.conf
|
||||
install -Dm644 data/cortex-firewalld.xml %{buildroot}%{_prefix}/lib/firewalld/services/cortex.xml
|
||||
install -dm755 %{buildroot}%{_sysconfdir}/cortex
|
||||
install -Dm644 cortex.example.toml %{buildroot}%{_sysconfdir}/cortex/cortex.toml
|
||||
install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||
@@ -74,53 +72,16 @@ install -Dm644 models.example.toml %{buildroot}%{_sysconfdir}/cortex/models.toml
|
||||
%postun
|
||||
%systemd_postun_with_restart cortex.service
|
||||
|
||||
%posttrans
|
||||
# Migration: older cortex packages shipped the firewalld service as
|
||||
# `helexa-cortex` and (in some build streams) with wrong port numbers
|
||||
# (9301/9302/9304). Operators who enabled that legacy service in their
|
||||
# zone end up with the wrong-port override taking precedence over the
|
||||
# vendor `cortex.xml` now in /usr/lib/firewalld/services/. Clean up the
|
||||
# stale /etc/ override here and migrate any zone bindings to the new
|
||||
# service name.
|
||||
if [ -f /etc/firewalld/services/helexa-cortex.xml ]; then
|
||||
rm -f /etc/firewalld/services/helexa-cortex.xml
|
||||
fi
|
||||
if [ -x /usr/bin/firewall-cmd ] && /usr/bin/firewall-cmd --state >/dev/null 2>&1; then
|
||||
# Drop the legacy service name from every zone where it was enabled
|
||||
# and add the new `cortex` service in its place. Operators who never
|
||||
# ran firewall-cmd against either name see no zone change.
|
||||
for zone in $(/usr/bin/firewall-cmd --get-active-zones 2>/dev/null \
|
||||
| awk '!/^[[:space:]]/ {print $1}'); do
|
||||
if /usr/bin/firewall-cmd --permanent --zone="$zone" --query-service=helexa-cortex >/dev/null 2>&1; then
|
||||
/usr/bin/firewall-cmd --permanent --zone="$zone" --remove-service=helexa-cortex >/dev/null 2>&1 || :
|
||||
/usr/bin/firewall-cmd --permanent --zone="$zone" --add-service=cortex >/dev/null 2>&1 || :
|
||||
fi
|
||||
done
|
||||
/usr/bin/firewall-cmd --reload >/dev/null 2>&1 || :
|
||||
fi
|
||||
:
|
||||
|
||||
%files
|
||||
%license LICENSE
|
||||
%doc README.md
|
||||
%{_bindir}/cortex
|
||||
%{_unitdir}/cortex.service
|
||||
%{_sysusersdir}/cortex.conf
|
||||
%{_prefix}/lib/firewalld/services/cortex.xml
|
||||
%dir %{_sysconfdir}/cortex
|
||||
%config(noreplace) %{_sysconfdir}/cortex/cortex.toml
|
||||
%config(noreplace) %{_sysconfdir}/cortex/models.toml
|
||||
|
||||
%changelog
|
||||
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.16-1
|
||||
- chore: ignore local deploy script
|
||||
- chore: move default ports out of common-collision ranges
|
||||
- ci: drop actions/cache for cargo registry and target
|
||||
|
||||
* Thu Apr 16 2026 Gitea Actions <actions@git.lair.cafe> - 0.1.14-1
|
||||
- ci: publish both packages to a single helexa/helexa COPR project
|
||||
- fix(rpm): rename neuron package to helexa-neuron
|
||||
- ci: commit generated %changelog entries back to main
|
||||
|
||||
* Wed Apr 15 2026 Rob Thijssen <grenade@rob.tn> - 0.1.0-1
|
||||
- Initial package
|
||||
|
||||
@@ -5,7 +5,7 @@ use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "cortex")]
|
||||
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
|
||||
#[command(about = "Unified inference gateway for multi-node mistral.rs clusters")]
|
||||
#[command(version)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! These mirror the `/v1/messages` format used by the Anthropic API.
|
||||
//! The gateway accepts these, translates to OpenAI format, proxies to
|
||||
//! the inference backend (neuron), then translates the response back.
|
||||
//! mistral.rs, then translates the response back.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
//! Build/version metadata shared between cortex and neuron.
|
||||
//!
|
||||
//! neuron captures these facts at compile time in its `build.rs`
|
||||
//! (git SHA, enabled cargo features, rustc/candle versions, …) and
|
||||
//! serves them from `GET /version`. cortex and `helexa-bench`
|
||||
//! deserialize the same struct so a benchmark run can be attributed to
|
||||
//! the exact daemon build that produced it — not just the host's CUDA
|
||||
//! and driver versions that `/discovery` already reports.
|
||||
//!
|
||||
//! Every field beyond the always-present package version is
|
||||
//! `#[serde(default)]` so a newer reader stays compatible with an
|
||||
//! older neuron that omits a field (and vice versa) — the same
|
||||
//! forward/backward-compat discipline as
|
||||
//! [`crate::discovery::ActivationStatus`].
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Build-time identity of a neuron daemon.
|
||||
///
|
||||
/// Returned by `GET /version`. The `git_sha` is the canonical "which
|
||||
/// build is live" key — benchmark records are bucketed by it, so a
|
||||
/// regression can be pinned to a daemon change rather than a host
|
||||
/// change. When neuron is built from a source tarball with no git
|
||||
/// metadata available (and no `HELEXA_BUILD_SHA` injected by CI/RPM),
|
||||
/// `git_sha` is the string `"unknown"`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct BuildInfo {
|
||||
/// Crate version from `CARGO_PKG_VERSION` (e.g. `"0.1.16"`).
|
||||
pub package_version: String,
|
||||
/// Short git SHA, or `"unknown"` when unavailable at build time.
|
||||
#[serde(default = "unknown")]
|
||||
pub git_sha: String,
|
||||
/// Full 40-char git SHA when available.
|
||||
#[serde(default)]
|
||||
pub git_sha_long: Option<String>,
|
||||
/// Whether the working tree had uncommitted changes at build time.
|
||||
/// `false` when the SHA is unknown (tarball build).
|
||||
#[serde(default)]
|
||||
pub git_dirty: bool,
|
||||
/// RFC3339 build timestamp.
|
||||
#[serde(default)]
|
||||
pub build_timestamp: Option<String>,
|
||||
/// `rustc --version` output of the compiler used.
|
||||
#[serde(default)]
|
||||
pub rustc_version: Option<String>,
|
||||
/// Cargo build profile: `"release"` or `"debug"`.
|
||||
#[serde(default)]
|
||||
pub profile: Option<String>,
|
||||
/// Target triple the binary was compiled for.
|
||||
#[serde(default)]
|
||||
pub target: Option<String>,
|
||||
/// Enabled cargo features (e.g. `["cuda", "cudnn"]`). These define
|
||||
/// the performance envelope, so they are recorded against every
|
||||
/// benchmark run.
|
||||
#[serde(default)]
|
||||
pub features: Vec<String>,
|
||||
/// Locked `candle-core` version, best-effort from `Cargo.lock`.
|
||||
#[serde(default)]
|
||||
pub candle_version: Option<String>,
|
||||
}
|
||||
|
||||
fn unknown() -> String {
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
impl BuildInfo {
|
||||
/// A placeholder used by non-neuron benchmark targets (and tests)
|
||||
/// that have no build metadata to report.
|
||||
pub fn unknown() -> Self {
|
||||
BuildInfo {
|
||||
package_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
git_sha: unknown(),
|
||||
git_sha_long: None,
|
||||
git_dirty: false,
|
||||
build_timestamp: None,
|
||||
rustc_version: None,
|
||||
profile: None,
|
||||
target: None,
|
||||
features: Vec::new(),
|
||||
candle_version: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated record must survive a JSON encode/decode cycle
    /// unchanged.
    #[test]
    fn round_trips_full() {
        let original = BuildInfo {
            package_version: "0.1.16".into(),
            git_sha: "30d50d6".into(),
            git_sha_long: Some("30d50d6abc123".into()),
            git_dirty: true,
            build_timestamp: Some("2026-06-13T10:00:00+00:00".into()),
            rustc_version: Some("rustc 1.85.0".into()),
            profile: Some("release".into()),
            target: Some("x86_64-unknown-linux-gnu".into()),
            features: vec!["cuda".into(), "cudnn".into()],
            candle_version: Some("0.10.2".into()),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: BuildInfo = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// An older neuron might send only the package version; every other
    /// field must default rather than fail.
    #[test]
    fn deserializes_minimal_payload() {
        let payload = r#"{"package_version":"0.1.0"}"#;
        let parsed: BuildInfo = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.package_version, "0.1.0");
        assert_eq!(parsed.git_sha, "unknown");
        assert!(!parsed.git_dirty);
        assert!(parsed.features.is_empty());
        assert!(parsed.candle_version.is_none());
    }
}
|
||||
@@ -1,8 +1,6 @@
|
||||
//! Model catalogue — profiles describing how to serve each model.
|
||||
|
||||
use crate::discovery::DeviceInfo;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// A model serving profile loaded from models.toml.
|
||||
@@ -24,17 +22,6 @@ pub struct ModelProfile {
|
||||
/// Neurons where this model should never be evicted.
|
||||
#[serde(default)]
|
||||
pub pinned_on: Vec<String>,
|
||||
/// Source scheme this profile's weights come from. When set, the
|
||||
/// router prefixes `id` with `scheme:` before forwarding the load
|
||||
/// request to neuron, ensuring the daemon fetches from the right
|
||||
/// registry regardless of which entry happens to match `id`.
|
||||
///
|
||||
/// `None` lets neuron substitute its own `default_source` (typically
|
||||
/// `huggingface`). Set to `"helexa"` when the model is hosted in
|
||||
/// the helexa registry — operator-procurement-grade audit relies
|
||||
/// on this being explicit per model rather than implicit.
|
||||
#[serde(default)]
|
||||
pub source: Option<String>,
|
||||
}
|
||||
|
||||
fn default_min_devices() -> u32 {
|
||||
@@ -46,14 +33,6 @@ fn default_min_devices() -> u32 {
|
||||
pub struct ModelCatalogue {
|
||||
#[serde(default)]
|
||||
pub models: Vec<ModelProfile>,
|
||||
/// Tier aliases — clients can send a request with `model: "helexa/small"`
|
||||
/// and the gateway transparently rewrites + routes to the concrete
|
||||
/// model id this maps to. Lets operators define latency/quality
|
||||
/// tiers (`small`/`balanced`/`large`, `fast`/`thinking`, etc.)
|
||||
/// without imposing knowledge of specific model ids on clients.
|
||||
/// Loaded from the `[aliases]` table in models.toml.
|
||||
#[serde(default)]
|
||||
pub aliases: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ModelCatalogue {
|
||||
@@ -85,162 +64,4 @@ impl ModelCatalogue {
|
||||
.iter()
|
||||
.any(|p| p.id == model_id && p.pinned_on.contains(&neuron_name.to_string()))
|
||||
}
|
||||
|
||||
/// Find a profile by model id.
|
||||
pub fn get(&self, model_id: &str) -> Option<&ModelProfile> {
|
||||
self.models.iter().find(|p| p.id == model_id)
|
||||
}
|
||||
|
||||
/// Resolve an alias to its concrete model id. Returns `id` verbatim
|
||||
/// when it isn't an alias. Aliases never chain — operator config
|
||||
/// is treated as flat — so this is a single lookup.
|
||||
pub fn resolve_alias<'a>(&'a self, id: &'a str) -> &'a str {
|
||||
self.aliases.get(id).map(String::as_str).unwrap_or(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelProfile {
|
||||
/// True iff this profile's placement constraints can be satisfied
|
||||
/// by the named neuron with the given device topology.
|
||||
///
|
||||
/// Constraints checked:
|
||||
/// - `pinned_on`: non-empty → neuron must be on the list.
|
||||
/// - `min_devices`: neuron must have at least this many devices.
|
||||
/// - `min_device_vram_mb`: at least `min_devices` of the neuron's
|
||||
/// devices must each meet this VRAM floor.
|
||||
pub fn is_feasible_on(&self, neuron_name: &str, devices: &[DeviceInfo]) -> bool {
|
||||
if !self.pinned_on.is_empty() && !self.pinned_on.iter().any(|n| n == neuron_name) {
|
||||
return false;
|
||||
}
|
||||
if (devices.len() as u32) < self.min_devices {
|
||||
return false;
|
||||
}
|
||||
if let Some(min_vram) = self.min_device_vram_mb {
|
||||
let big_enough = devices
|
||||
.iter()
|
||||
.filter(|d| d.vram_total_mb >= min_vram)
|
||||
.count() as u32;
|
||||
if big_enough < self.min_devices {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::DeviceInfo;

    /// Build a test device with the given index and total VRAM.
    fn device(index: u32, vram_total_mb: u64) -> DeviceInfo {
        DeviceInfo {
            index,
            name: format!("DEV-{index}"),
            vram_total_mb,
            compute_capability: "8.6".into(),
        }
    }

    /// Baseline profile: two devices of at least 24 000 MB each, unpinned.
    fn profile() -> ModelProfile {
        ModelProfile {
            id: "Qwen/Qwen3.6-27B".into(),
            harness: "candle".into(),
            quant: None,
            vram_mb: Some(45_000),
            min_devices: 2,
            min_device_vram_mb: Some(24_000),
            pinned_on: vec![],
            source: None,
        }
    }

    #[test]
    fn feasible_when_two_devices_meet_vram_floor() {
        let pair = [device(0, 32_000), device(1, 32_000)];
        assert!(profile().is_feasible_on("beast", &pair));
    }

    #[test]
    fn infeasible_when_only_one_device() {
        let single = [device(0, 64_000)];
        assert!(!profile().is_feasible_on("benjy", &single));
    }

    #[test]
    fn infeasible_when_one_device_underspec() {
        let uneven = [device(0, 32_000), device(1, 12_000)];
        assert!(!profile().is_feasible_on("mixed", &uneven));
    }

    #[test]
    fn pinned_on_excludes_other_neurons() {
        let mut pinned = profile();
        pinned.pinned_on = vec!["beast".into()];
        let pair = [device(0, 32_000), device(1, 32_000)];

        assert!(pinned.is_feasible_on("beast", &pair));
        assert!(!pinned.is_feasible_on("benjy", &pair));
    }

    #[test]
    fn no_vram_floor_just_needs_min_devices() {
        let mut unfloored = profile();
        unfloored.min_device_vram_mb = None;
        let small_pair = [device(0, 1_000), device(1, 1_000)];

        assert!(unfloored.is_feasible_on("anywhere", &small_pair));
    }

    #[test]
    fn resolve_alias_returns_target_when_alias_present() {
        let mut catalogue = ModelCatalogue::default();
        catalogue
            .aliases
            .insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());

        assert_eq!(catalogue.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
    }

    #[test]
    fn resolve_alias_passes_through_when_not_an_alias() {
        let mut catalogue = ModelCatalogue::default();
        catalogue
            .aliases
            .insert("helexa/small".into(), "Qwen/Qwen3-1.7B".into());

        assert_eq!(catalogue.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
    }

    #[test]
    fn source_defaults_to_none_when_absent_from_toml() {
        let toml_text = r#"
[[models]]
id = "Qwen/Qwen3-30B"
harness = "candle"
"#;
        let parsed: ModelCatalogue = toml::from_str(toml_text).expect("parse models table");

        assert!(parsed.models[0].source.is_none());
    }

    #[test]
    fn source_round_trips_through_toml() {
        let toml_text = r#"
[[models]]
id = "Helexa/Qwen3.6-27B-Uncensored"
harness = "candle"
source = "helexa"
"#;
        let parsed: ModelCatalogue = toml::from_str(toml_text).expect("parse models table");

        assert_eq!(parsed.models[0].source.as_deref(), Some("helexa"));
    }

    #[test]
    fn aliases_table_round_trips_through_toml() {
        let toml_text = r#"
[aliases]
"helexa/small" = "Qwen/Qwen3-1.7B"
"helexa/large" = "Qwen/Qwen3.6-27B"
"#;
        let parsed: ModelCatalogue = toml::from_str(toml_text).expect("parse aliases table");

        assert_eq!(parsed.resolve_alias("helexa/small"), "Qwen/Qwen3-1.7B");
        assert_eq!(parsed.resolve_alias("helexa/large"), "Qwen/Qwen3.6-27B");
    }
}
|
||||
|
||||
@@ -22,17 +22,6 @@ pub struct DiscoveryResponse {
|
||||
pub driver_version: Option<String>,
|
||||
pub devices: Vec<DeviceInfo>,
|
||||
pub harnesses: Vec<String>,
|
||||
/// Set when the host has an NVIDIA stack that is currently
|
||||
/// unusable — specifically the userspace↔kernel-module version
|
||||
/// skew after an un-rebooted driver update ("Driver/library
|
||||
/// version mismatch"), where every CUDA call including nvidia-smi
|
||||
/// fails (#19). `None` on healthy hosts AND on hosts with no
|
||||
/// NVIDIA stack at all (CPU-only is not an error). Carries an
|
||||
/// operator-actionable description; cortex can read it to route
|
||||
/// around the node instead of cold-loading into a guaranteed
|
||||
/// failure.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cuda_unavailable_reason: Option<String>,
|
||||
}
|
||||
|
||||
/// Runtime health metrics for a single GPU device.
|
||||
@@ -47,72 +36,8 @@ pub struct DeviceHealth {
|
||||
|
||||
/// Runtime health response from a neuron endpoint.
|
||||
/// Returned by `GET /health`.
|
||||
///
|
||||
/// `activation` was added in 2026-05-26 to distinguish "process is up
|
||||
/// and reachable" from "process is ready to serve traffic". A `Type=simple`
|
||||
/// systemd unit reports `active` the moment the binary starts — but a
|
||||
/// neuron whose `default_models` list takes minutes to materialise
|
||||
/// won't bind its listener (or, in the new flow, won't have any models
|
||||
/// loaded) until pre-warm completes. The new field is `#[serde(default)]`
|
||||
/// so a pre-2026-05-26 gateway polling a new neuron — or vice versa —
|
||||
/// keeps working.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HealthResponse {
|
||||
pub uptime_secs: u64,
|
||||
pub devices: Vec<DeviceHealth>,
|
||||
#[serde(default)]
|
||||
pub activation: ActivationStatus,
|
||||
}
|
||||
|
||||
/// High-level activation state of the neuron daemon. The HTTP listener
|
||||
/// is bound during both states; what differs is whether the configured
|
||||
/// `default_models` have finished loading.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ActivationState {
|
||||
/// At least one `default_models` entry is still loading. The
|
||||
/// neuron's other endpoints work, but inference against
|
||||
/// not-yet-loaded models will 404.
|
||||
PreWarming,
|
||||
/// Every `default_models` entry has either loaded or failed; the
|
||||
/// neuron is steady-state. Subsequent on-demand loads via
|
||||
/// `/models/load` don't flip back to PreWarming — that field
|
||||
/// reflects the activation-time set only.
|
||||
#[default]
|
||||
Ready,
|
||||
}
|
||||
|
||||
/// Per-model failure record surfaced in [`ActivationStatus::failed`].
|
||||
/// The error string is the rendered anyhow chain at the time of the
|
||||
/// failure; operators read it from `/health` to decide whether to
|
||||
/// retry, edit the spec, or unload+reload.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PreWarmFailure {
|
||||
pub model_id: String,
|
||||
pub error: String,
|
||||
}
|
||||
|
||||
/// Activation-time progress snapshot. All four lists are populated by
|
||||
/// the neuron's pre-warm task and read by the `/health` handler. The
|
||||
/// snapshot is consistent: a model id appears in exactly one of
|
||||
/// `pending`, `in_progress` (as `Option<String>`), `completed`, or
|
||||
/// `failed` at any point in time.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ActivationStatus {
|
||||
pub state: ActivationState,
|
||||
/// Model ids queued but not yet started. Empty in `Ready` state.
|
||||
#[serde(default)]
|
||||
pub pending: Vec<String>,
|
||||
/// Model id currently materialising. None when between models or
|
||||
/// in `Ready` state.
|
||||
#[serde(default)]
|
||||
pub in_progress: Option<String>,
|
||||
/// Model ids that finished loading successfully during this
|
||||
/// activation. Cleared on process restart.
|
||||
#[serde(default)]
|
||||
pub completed: Vec<String>,
|
||||
/// Model ids that failed during this activation, with the rendered
|
||||
/// error chain. Cleared on process restart.
|
||||
#[serde(default)]
|
||||
pub failed: Vec<PreWarmFailure>,
|
||||
}
|
||||
|
||||
@@ -9,13 +9,13 @@ use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for a harness instance on a neuron.
|
||||
///
|
||||
/// All current harnesses are in-process (candle); per-harness tuning
|
||||
/// (cache paths, device policies, etc.) lives in dedicated config
|
||||
/// blocks rather than on this struct.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HarnessConfig {
|
||||
pub name: String,
|
||||
/// Base URL of the harness (e.g. "http://localhost:8080" for mistral.rs).
|
||||
pub endpoint: Option<String>,
|
||||
/// Systemd unit name, if the harness is managed via systemd.
|
||||
pub systemd_unit: Option<String>,
|
||||
}
|
||||
|
||||
/// Health status of a harness process.
|
||||
@@ -44,37 +44,19 @@ pub struct ModelInfo {
|
||||
pub status: String,
|
||||
pub devices: Vec<u32>,
|
||||
pub vram_used_mb: Option<u64>,
|
||||
/// Modalities this loaded model supports. Today: `["text"]` for
|
||||
/// text-only checkpoints, `["text", "vision"]` for vision-capable
|
||||
/// ones (Stage B7 of the vision plan). Clients like litellm /
|
||||
/// agent0 can gate `image_url` submission on the advertised set.
|
||||
///
|
||||
/// Optional in the wire format so older clients that don't read
|
||||
/// it stay compatible. Default-empty for absent/older data, which
|
||||
/// callers can interpret as "text".
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
/// What an inference harness must do, from neuron's perspective.
|
||||
///
|
||||
/// All current harnesses are in-process — they share neuron's address
|
||||
/// space and lifecycle. `start`/`stop` therefore default to no-ops; a
|
||||
/// future process-supervising harness would override them.
|
||||
#[async_trait]
|
||||
pub trait Harness: Send + Sync {
|
||||
/// Human-readable name (e.g. "candle").
|
||||
/// Human-readable name (e.g. "mistralrs", "llamacpp", "comfyui").
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Start the harness. Default no-op for in-process harnesses.
|
||||
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
/// Start the harness process if it is not already running.
|
||||
async fn start(&self, config: &HarnessConfig) -> Result<()>;
|
||||
|
||||
/// Stop the harness. Default no-op for in-process harnesses.
|
||||
async fn stop(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
/// Stop the harness process gracefully.
|
||||
async fn stop(&self) -> Result<()>;
|
||||
|
||||
/// Health check. Returns the harness process status.
|
||||
async fn health(&self) -> HarnessHealth;
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
pub mod anthropic;
|
||||
pub mod build_info;
|
||||
pub mod catalogue;
|
||||
pub mod config;
|
||||
pub mod discovery;
|
||||
@@ -7,6 +6,4 @@ pub mod harness;
|
||||
pub mod metrics;
|
||||
pub mod node;
|
||||
pub mod openai;
|
||||
pub mod responses;
|
||||
pub mod source;
|
||||
pub mod translate;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::discovery::{ActivationStatus, DiscoveryResponse};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@@ -14,18 +13,6 @@ pub struct NodeState {
|
||||
/// Number of load/unload cycles since last process restart.
|
||||
pub lifecycle_cycles: u32,
|
||||
pub last_poll: Option<DateTime<Utc>>,
|
||||
/// Result of the most recent successful `GET /discovery` against
|
||||
/// this neuron. Cached forever once obtained — device topology is
|
||||
/// invariant for a given neuron process. `None` until the first
|
||||
/// successful poll. Used by the router and `/v1/models` to do
|
||||
/// catalogue × topology feasibility checks.
|
||||
pub discovery: Option<DiscoveryResponse>,
|
||||
/// Last-seen pre-warm progress from this neuron's `/health`
|
||||
/// endpoint. `None` until the first /health poll succeeds. The
|
||||
/// `/v1/models` handler reads `in_progress` + `pending` from here
|
||||
/// to synthesize `Loading` locations so clients see a catalogued
|
||||
/// model that's mid-prewarm as "loading", not "missing".
|
||||
pub activation: Option<ActivationStatus>,
|
||||
}
|
||||
|
||||
/// A model registered on a node, with its runtime status.
|
||||
@@ -37,72 +24,25 @@ pub struct ModelEntry {
|
||||
pub last_accessed: Option<DateTime<Utc>>,
|
||||
/// Estimated VRAM usage in MB when loaded.
|
||||
pub vram_estimate_mb: Option<u64>,
|
||||
/// Modalities the loaded model advertises (e.g. `["text", "vision"]`),
|
||||
/// copied verbatim from the neuron's `ModelInfo.capabilities` at poll
|
||||
/// time. Empty when the neuron reports none. `#[serde(default)]` keeps
|
||||
/// older persisted/serialised entries deserialisable.
|
||||
#[serde(default)]
|
||||
pub capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
/// Model lifecycle status.
|
||||
///
|
||||
/// `Loading` is a gateway-side synthetic status: neurons never emit it
|
||||
/// on `/models` (that endpoint only knows about already-loaded handles).
|
||||
/// The gateway populates it from a neuron's `/health` activation
|
||||
/// snapshot so the unified `/v1/models` can distinguish "model is
|
||||
/// catalogued but no one has it" from "model is materialising on
|
||||
/// neuron N right now". Other status values are reported verbatim by
|
||||
/// neurons.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ModelStatus {
|
||||
Loaded,
|
||||
Unloaded,
|
||||
Reloading,
|
||||
Loading,
|
||||
/// Reported by neuron while a poisoned model auto-recovers via
|
||||
/// unload→reload (#17/#20). Temporarily unservable but NOT
|
||||
/// evicted: the gateway holds the route, answers with a transient
|
||||
/// retry error instead of 404, and must not race a second
|
||||
/// placement elsewhere.
|
||||
Recovering,
|
||||
}
|
||||
|
||||
/// Unified model entry as exposed by the gateway's `/v1/models` endpoint.
|
||||
///
|
||||
/// The first four fields (`id`, `object`, `created`, `owned_by`) match
|
||||
/// OpenAI's `/v1/models` shape verbatim, so existing OpenAI-aware
|
||||
/// tooling deserialises this without custom code. The remaining fields
|
||||
/// are helexa-specific extensions — OpenAI clients ignore unknown
|
||||
/// fields and other consumers can read them for placement / debugging.
|
||||
/// Includes which node(s) host this model and their status.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CortexModelEntry {
|
||||
pub id: String,
|
||||
/// Always `"model"` per OpenAI's contract.
|
||||
pub object: String,
|
||||
/// Unix-second timestamp; cortex stamps this at response time.
|
||||
pub created: u64,
|
||||
/// OpenAI's "publisher" field — `"helexa"` for everything we serve.
|
||||
pub owned_by: String,
|
||||
/// True if any neuron currently has this model loaded. False for
|
||||
/// catalogue entries that are feasible but not yet loaded.
|
||||
pub loaded: bool,
|
||||
/// Neurons whose discovered topology can satisfy this model's
|
||||
/// catalogue placement constraints. Empty for models that are
|
||||
/// loaded somewhere but not present in the catalogue (cortex has
|
||||
/// no feasibility opinion on those).
|
||||
pub feasible_on: Vec<String>,
|
||||
/// Where this model is actually loaded right now. Subset of (or
|
||||
/// disjoint from) `feasible_on` depending on whether the catalogue
|
||||
/// covers this model.
|
||||
/// Which nodes have this model (and their status).
|
||||
pub locations: Vec<ModelLocation>,
|
||||
/// Union of the modalities advertised by every neuron that has this
|
||||
/// model loaded (e.g. `["text", "vision"]`). Empty for catalogue-only
|
||||
/// entries with no loaded location — the catalogue profile doesn't
|
||||
/// declare capabilities yet (tracked separately from C3).
|
||||
#[serde(default)]
|
||||
pub capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! These are a subset sufficient for chat completions (streaming + non-streaming).
|
||||
//! Fields not relevant to proxying are captured as `serde_json::Value` via
|
||||
//! `#[serde(flatten)]` so we forward them without needing to enumerate every
|
||||
//! extension field a backend might support.
|
||||
//! extension field mistral.rs supports.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -22,7 +22,7 @@ pub struct ChatCompletionRequest {
|
||||
pub max_tokens: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
/// All other fields (tools, response_format, backend extensions, etc.)
|
||||
/// All other fields (tools, response_format, mistral.rs extensions, etc.)
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
@@ -71,18 +71,10 @@ pub struct ChatCompletionChoice {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionChunk {
|
||||
#[serde(default)]
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub object: String,
|
||||
#[serde(default)]
|
||||
pub created: u64,
|
||||
// Lenient deserialization throughout: the gateway parses chunks
|
||||
// from arbitrary OpenAI-compatible upstreams, and some engines
|
||||
// omit fields on special frames (e.g. usage-only final chunks).
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
|
||||
@@ -1,346 +0,0 @@
|
||||
//! OpenAI Responses API (`POST /v1/responses`) envelope types.
|
||||
//!
|
||||
//! This is OpenAI's newer chat surface, distinct from
|
||||
//! `/v1/chat/completions` in three ways that matter for us:
|
||||
//!
|
||||
//! 1. **Input shape**. Instead of a `messages` array, the request
|
||||
//! carries `input` — either a plain string (single user turn)
|
||||
//! or an array of typed items (messages, function calls,
|
||||
//! function-call outputs, reasoning blocks, …).
|
||||
//! 2. **Output shape**. The response carries a single `output`
|
||||
//! array of items, each typed. We always emit one
|
||||
//! `OutputItem::Message` containing the assistant's reply (plus,
|
||||
//! when we get there, separate `function_call` items).
|
||||
//! 3. **Streaming events**. Where chat completions stream
|
||||
//! structurally-identical `chat.completion.chunk` frames over
|
||||
//! `data:` lines, Responses streams *named* events
|
||||
//! (`response.created`, `response.output_text.delta`,
|
||||
//! `response.completed`, …) over `event:` + `data:` SSE pairs.
|
||||
//! The wire projector in `neuron::wire::openai_responses` builds
|
||||
//! these from the same [`crate::openai`]-shaped
|
||||
//! `InferenceEvent` stream the chat projector consumes.
|
||||
//!
|
||||
//! Scope limits for this first version:
|
||||
//!
|
||||
//! - **`previous_response_id` is rejected at parse time**. Stateful
|
||||
//! chained conversations need a persistence layer we don't have.
|
||||
//! - **Reasoning items are accepted-and-ignored** (no Qwen3
|
||||
//! `<think>` routing yet). Audio and embedded resources are
|
||||
//! rejected as unsupported.
|
||||
//! - **Tool calls** (function_call / function_call_output) are
|
||||
//! carried as round-trip types but the candle harness doesn't
|
||||
//! emit them yet — wired so the surface is in place for the
|
||||
//! day we add proper tool-call extraction.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ── Request ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Body of a `POST /v1/responses` request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesRequest {
|
||||
pub model: String,
|
||||
pub input: ResponsesInput,
|
||||
/// System-prompt-style instructions. The Responses API
|
||||
/// separates these from input so a caller doesn't have to
|
||||
/// build a `system` message item by hand.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub instructions: Option<String>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_output_tokens: Option<u64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
/// Chained-conversation identifier. We don't store responses
|
||||
/// server-side yet; if this is `Some`, the handler returns 400.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub previous_response_id: Option<String>,
|
||||
/// Catch-all for anything we don't model yet (tools, tool_choice,
|
||||
/// reasoning, response_format, …). Lets a client send a
|
||||
/// forward-compatible request without our parser rejecting it.
|
||||
#[serde(flatten)]
|
||||
pub extra: Value,
|
||||
}
|
||||
|
||||
/// `input` is either a single string or an array of typed items.
|
||||
/// `#[serde(untagged)]` so the wire shape `"input": "hi"` and
|
||||
/// `"input": [{...}]` both deserialize.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponsesInput {
|
||||
Text(String),
|
||||
Items(Vec<ResponsesInputItem>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesInputItem {
|
||||
/// A user / assistant / system turn.
|
||||
Message {
|
||||
role: String,
|
||||
content: ResponsesMessageContent,
|
||||
},
|
||||
/// Assistant emitted a tool call. Round-trip only — neuron
|
||||
/// doesn't synthesise these yet.
|
||||
FunctionCall {
|
||||
call_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
},
|
||||
/// User is feeding a tool result back into the model.
|
||||
FunctionCallOutput { call_id: String, output: String },
|
||||
/// Reasoning items emitted by o-series models. Accepted but
|
||||
/// not forwarded to the model — neuron's candle path doesn't
|
||||
/// surface reasoning separately yet.
|
||||
Reasoning {
|
||||
#[serde(default)]
|
||||
content: Vec<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Inside a `Message` item, content is either a plain string or an
|
||||
/// array of typed parts. Mirrors the chat-completions Parts shape.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponsesMessageContent {
|
||||
Text(String),
|
||||
Parts(Vec<ResponsesContentPart>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesContentPart {
|
||||
/// Plain text inside a user / system turn.
|
||||
InputText { text: String },
|
||||
/// An image. `image_url` is either a remote URL or a
|
||||
/// `data:image/png;base64,…` URI; the request translator just
|
||||
/// forwards the string.
|
||||
InputImage {
|
||||
image_url: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
detail: Option<String>,
|
||||
},
|
||||
/// Returned text inside an assistant turn — only relevant when
|
||||
/// the caller is feeding an assistant turn back in to continue
|
||||
/// a conversation manually (no `previous_response_id`).
|
||||
OutputText {
|
||||
text: String,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
annotations: Vec<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
// ── Response (non-streaming) ─────────────────────────────────────────
|
||||
|
||||
/// Body of a `POST /v1/responses` response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesResponse {
|
||||
pub id: String,
|
||||
/// Always `"response"`.
|
||||
pub object: String,
|
||||
pub created_at: u64,
|
||||
/// `"completed"`, `"incomplete"`, or — for the initial event of
|
||||
/// a streaming response — `"in_progress"`.
|
||||
pub status: String,
|
||||
pub model: String,
|
||||
pub output: Vec<ResponsesOutputItem>,
|
||||
/// Populated on completion; `None` while streaming.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<ResponsesUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesOutputItem {
|
||||
Message {
|
||||
id: String,
|
||||
/// Always `"assistant"` for model output.
|
||||
role: String,
|
||||
/// Output content parts. We always emit a single
|
||||
/// `OutputText` today; multi-part output would land here
|
||||
/// once we have e.g. image generation.
|
||||
content: Vec<ResponsesOutputContent>,
|
||||
/// Item-level status. `"in_progress"` while streaming the
|
||||
/// content parts, `"completed"` when done.
|
||||
#[serde(default = "default_item_status")]
|
||||
status: String,
|
||||
},
|
||||
/// Reserved for the day tool-call extraction lands. The wire
|
||||
/// shape mirrors `ResponsesInputItem::FunctionCall`.
|
||||
FunctionCall {
|
||||
id: String,
|
||||
call_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
#[serde(default = "default_item_status")]
|
||||
status: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Serde default for an output item's `status`: `"completed"`.
fn default_item_status() -> String {
    String::from("completed")
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsesOutputContent {
|
||||
OutputText {
|
||||
text: String,
|
||||
/// Citations / inline annotations. Empty today; reserved
|
||||
/// for the day we wire in web search / file search.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
annotations: Vec<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
// ── Streaming event names ────────────────────────────────────────────
|
||||
|
||||
/// Names of the SSE events the projector emits.
///
/// They are hoisted into constants so the projector and the wire shape
/// stay in sync without string typos. The strings themselves are
/// dictated by OpenAI's published Responses API.
pub mod events {
    /// `response.created`
    pub const CREATED: &str = "response.created";

    /// `response.in_progress` — fired between `response.created` and
    /// the first output-item event. Marks "request validated, model is
    /// generating"; some clients use it to tell the "warming up" state
    /// from "streaming tokens" in their UI.
    pub const IN_PROGRESS: &str = "response.in_progress";

    /// `response.output_item.added`
    pub const OUTPUT_ITEM_ADDED: &str = "response.output_item.added";

    /// `response.content_part.added`
    pub const CONTENT_PART_ADDED: &str = "response.content_part.added";

    /// `response.output_text.delta`
    pub const OUTPUT_TEXT_DELTA: &str = "response.output_text.delta";

    /// `response.output_text.done`
    pub const OUTPUT_TEXT_DONE: &str = "response.output_text.done";

    /// `response.content_part.done`
    pub const CONTENT_PART_DONE: &str = "response.content_part.done";

    /// `response.output_item.done`
    pub const OUTPUT_ITEM_DONE: &str = "response.output_item.done";

    /// `response.completed`
    pub const COMPLETED: &str = "response.completed";
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `raw` into a request, panicking if the JSON does not fit.
    fn parse_request(raw: &str) -> ResponsesRequest {
        serde_json::from_str(raw).unwrap()
    }

    /// A bare string `input` decodes to the `Text` form.
    #[test]
    fn deserialises_input_string_form() {
        let req = parse_request(r#"{"model": "m", "input": "hello"}"#);
        let ResponsesInput::Text(text) = &req.input else {
            panic!("expected Text, got {:?}", req.input);
        };
        assert_eq!(text, "hello");
    }

    /// An array `input` decodes to `Items`, here one plain-text message.
    #[test]
    fn deserialises_input_items_form() {
        let req = parse_request(
            r#"{
                "model": "m",
                "input": [
                    {"type": "message", "role": "user", "content": "hi"}
                ]
            }"#,
        );
        let ResponsesInput::Items(items) = &req.input else {
            panic!("expected Items, got {:?}", req.input);
        };
        assert_eq!(items.len(), 1);

        let ResponsesInputItem::Message { role, content } = &items[0] else {
            panic!("expected Message item, got {:?}", items[0]);
        };
        assert_eq!(role, "user");

        let ResponsesMessageContent::Text(text) = content else {
            panic!("expected Text content, got {content:?}");
        };
        assert_eq!(text, "hi");
    }

    /// Multi-part content keeps both the text part and the image part.
    #[test]
    fn deserialises_input_with_image() {
        let req = parse_request(
            r#"{
                "model": "m",
                "input": [
                    {"type": "message", "role": "user", "content": [
                        {"type": "input_text", "text": "what is this"},
                        {"type": "input_image", "image_url": "data:image/png;base64,AAA="}
                    ]}
                ]
            }"#,
        );
        let ResponsesInput::Items(items) = &req.input else {
            panic!("expected Items, got {:?}", req.input);
        };
        let ResponsesInputItem::Message {
            content: ResponsesMessageContent::Parts(parts),
            ..
        } = &items[0]
        else {
            panic!("expected Parts, got {:?}", items[0]);
        };

        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            ResponsesContentPart::InputText { text } if text == "what is this"
        ));
        assert!(matches!(
            &parts[1],
            ResponsesContentPart::InputImage { image_url, .. }
                if image_url == "data:image/png;base64,AAA="
        ));
    }

    /// Fields the struct does not model are retained in `extra`.
    #[test]
    fn unknown_fields_round_trip_via_extra() {
        let req = parse_request(
            r#"{
                "model": "m",
                "input": "hi",
                "tools": [{"type": "web_search"}],
                "reasoning": {"effort": "medium"}
            }"#,
        );
        for key in ["tools", "reasoning"] {
            assert!(req.extra.get(key).is_some());
        }
    }

    /// A fully populated response survives serialise → deserialise.
    #[test]
    fn response_round_trips_through_serde() {
        let message = ResponsesOutputItem::Message {
            id: "msg_1".into(),
            role: "assistant".into(),
            content: vec![ResponsesOutputContent::OutputText {
                text: "hi there".into(),
                annotations: vec![],
            }],
            status: "completed".into(),
        };
        let original = ResponsesResponse {
            id: "resp_1".into(),
            object: "response".into(),
            created_at: 1700,
            status: "completed".into(),
            model: "m".into(),
            output: vec![message],
            usage: Some(ResponsesUsage {
                input_tokens: 5,
                output_tokens: 3,
                total_tokens: 8,
            }),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ResponsesResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.id, "resp_1");
        assert_eq!(decoded.output.len(), 1);
    }
}
|
||||
@@ -1,267 +0,0 @@
|
||||
//! Scheme-qualified model identifiers.
|
||||
//!
|
||||
//! cortex/neuron historically resolves every model id through hf-hub
|
||||
//! against `https://huggingface.co`. Helexa is adding an EU-hosted
|
||||
//! registry (`registry.helexa.ai`) alongside HF — both speak the same
|
||||
//! HF-compatible wire format, but the bytes, jurisdiction, and trust
|
||||
//! root differ. Model ids therefore need a scheme:
|
||||
//!
|
||||
//! - `huggingface:Qwen/Qwen3.6-27B` — HF-hosted bytes
|
||||
//! - `helexa:Qwen/Qwen3.6-27B-Uncensored` — helexa registry bytes
|
||||
//! - `helexa:SomeOperator/CustomFinetune` — operator publishing
|
||||
//! under the helexa namespace; same scheme handles all `org/name`
|
||||
//! pairs hosted in that registry.
|
||||
//!
|
||||
//! Bare `org/name` parses with an empty scheme; the caller (typically
|
||||
//! a harness) substitutes its configured default scheme so existing
|
||||
//! configs keep working through the transition.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Parsed `scheme:org/name`. Bare `org/name` produces an empty scheme
|
||||
/// — call `with_default_scheme` (or check `is_scheme_unset`) to
|
||||
/// resolve before using.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ModelSourceId {
|
||||
pub scheme: String,
|
||||
pub org: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Errors from `ModelSourceId::from_str`. Carries the offending input
|
||||
/// so log lines / API errors can echo what the operator typed.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum ParseError {
|
||||
#[error("empty model id")]
|
||||
Empty,
|
||||
#[error("model id '{0}' is missing the '/' between org and name")]
|
||||
MissingSlash(String),
|
||||
#[error("model id '{0}' has an empty scheme before ':'")]
|
||||
EmptyScheme(String),
|
||||
#[error("model id '{0}' has an empty org")]
|
||||
EmptyOrg(String),
|
||||
#[error("model id '{0}' has an empty name")]
|
||||
EmptyName(String),
|
||||
#[error("model id '{0}' has a scheme containing '/' which is reserved for org/name")]
|
||||
SchemeContainsSlash(String),
|
||||
#[error("model id '{0}' has a name containing ':' which is reserved for the scheme prefix")]
|
||||
NameContainsColon(String),
|
||||
}
|
||||
|
||||
impl ModelSourceId {
|
||||
/// Construct directly from already-validated parts. Used by tests
|
||||
/// and call sites that have the fields separately; the public API
|
||||
/// for parsing user input is `FromStr`.
|
||||
pub fn new(scheme: impl Into<String>, org: impl Into<String>, name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
scheme: scheme.into(),
|
||||
org: org.into(),
|
||||
name: name.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// True when this id parsed from a bare `org/name` (no scheme
|
||||
/// prefix). The harness substitutes its configured default in
|
||||
/// `with_default_scheme` before resolving against a registry.
|
||||
pub fn is_scheme_unset(&self) -> bool {
|
||||
self.scheme.is_empty()
|
||||
}
|
||||
|
||||
/// Substitute `default` for an empty scheme. No-op when the scheme
|
||||
/// is already set. Returns self by value so it composes neatly:
|
||||
/// `id.parse::<ModelSourceId>()?.with_default_scheme("huggingface")`.
|
||||
pub fn with_default_scheme(mut self, default: &str) -> Self {
|
||||
if self.scheme.is_empty() {
|
||||
self.scheme = default.to_string();
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// The `org/name` half — what an hf-hub `Api::model(...)` call
|
||||
/// expects regardless of which scheme/endpoint we're hitting.
|
||||
pub fn repo_path(&self) -> String {
|
||||
format!("{}/{}", self.org, self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ModelSourceId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if self.scheme.is_empty() {
|
||||
write!(f, "{}/{}", self.org, self.name)
|
||||
} else {
|
||||
write!(f, "{}:{}/{}", self.scheme, self.org, self.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ModelSourceId {
|
||||
type Err = ParseError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if s.is_empty() {
|
||||
return Err(ParseError::Empty);
|
||||
}
|
||||
// Scheme split. Only the *first* colon counts — anything after
|
||||
// belongs to org/name (and would be rejected separately because
|
||||
// `:` isn't allowed there).
|
||||
let (scheme, rest) = match s.split_once(':') {
|
||||
Some((scheme, rest)) => {
|
||||
if scheme.is_empty() {
|
||||
return Err(ParseError::EmptyScheme(s.to_string()));
|
||||
}
|
||||
if scheme.contains('/') {
|
||||
return Err(ParseError::SchemeContainsSlash(s.to_string()));
|
||||
}
|
||||
(scheme.to_string(), rest)
|
||||
}
|
||||
None => (String::new(), s),
|
||||
};
|
||||
let (org, name) = rest
|
||||
.split_once('/')
|
||||
.ok_or_else(|| ParseError::MissingSlash(s.to_string()))?;
|
||||
if org.is_empty() {
|
||||
return Err(ParseError::EmptyOrg(s.to_string()));
|
||||
}
|
||||
if name.is_empty() {
|
||||
return Err(ParseError::EmptyName(s.to_string()));
|
||||
}
|
||||
if name.contains(':') {
|
||||
return Err(ParseError::NameContainsColon(s.to_string()));
|
||||
}
|
||||
Ok(Self {
|
||||
scheme,
|
||||
org: org.to_string(),
|
||||
name: name.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `input.parse::<ModelSourceId>()`.
    fn parse(input: &str) -> Result<ModelSourceId, ParseError> {
        input.parse()
    }

    #[test]
    fn parses_qualified() {
        let id = parse("huggingface:Qwen/Qwen3.6-27B").unwrap();
        assert_eq!(id.scheme, "huggingface");
        assert_eq!(id.org, "Qwen");
        assert_eq!(id.name, "Qwen3.6-27B");
        assert_eq!(id.repo_path(), "Qwen/Qwen3.6-27B");
        assert!(!id.is_scheme_unset());
    }

    #[test]
    fn parses_helexa_scheme() {
        let id = parse("helexa:SomeOperator/Qwen3.6-27B-Uncensored").unwrap();
        assert_eq!(id.scheme, "helexa");
        assert_eq!(id.org, "SomeOperator");
        assert_eq!(id.name, "Qwen3.6-27B-Uncensored");
    }

    #[test]
    fn parses_bare_id_with_empty_scheme() {
        let id = parse("Qwen/Qwen3-30B-A3B-Instruct").unwrap();
        assert_eq!(id.scheme, "");
        assert_eq!(id.org, "Qwen");
        assert_eq!(id.name, "Qwen3-30B-A3B-Instruct");
        assert!(id.is_scheme_unset());
    }

    #[test]
    fn substitutes_default_scheme_only_when_unset() {
        let bare = parse("Qwen/Q3").unwrap();
        assert_eq!(bare.with_default_scheme("huggingface").scheme, "huggingface");

        let explicit = parse("helexa:Qwen/Q3").unwrap();
        assert_eq!(
            explicit.with_default_scheme("huggingface").scheme,
            "helexa",
            "default substitution must not override an explicit scheme"
        );
    }

    #[test]
    fn display_roundtrips_qualified_id() {
        let text = "helexa:Helexa/Qwen3.6-27B";
        assert_eq!(parse(text).unwrap().to_string(), text);
    }

    #[test]
    fn display_roundtrips_bare_id() {
        let text = "Qwen/Q3";
        assert_eq!(parse(text).unwrap().to_string(), text);
    }

    #[test]
    fn rejects_empty() {
        assert_eq!(parse("").unwrap_err(), ParseError::Empty);
    }

    #[test]
    fn rejects_missing_slash() {
        // With and without a scheme, the error echoes the whole input.
        for input in ["Qwen", "huggingface:Qwen"] {
            assert_eq!(
                parse(input).unwrap_err(),
                ParseError::MissingSlash(input.to_string())
            );
        }
    }

    #[test]
    fn rejects_empty_scheme() {
        assert_eq!(
            parse(":Qwen/Q3").unwrap_err(),
            ParseError::EmptyScheme(":Qwen/Q3".to_string())
        );
    }

    #[test]
    fn rejects_scheme_with_slash() {
        assert_eq!(
            parse("hugg/ingface:Q/N").unwrap_err(),
            ParseError::SchemeContainsSlash("hugg/ingface:Q/N".to_string())
        );
    }

    #[test]
    fn rejects_empty_org_or_name() {
        let org_err = parse("huggingface:/N").unwrap_err();
        assert!(
            matches!(org_err, ParseError::EmptyOrg(_)),
            "expected EmptyOrg, got {org_err:?}"
        );

        let name_err = parse("huggingface:Q/").unwrap_err();
        assert!(
            matches!(name_err, ParseError::EmptyName(_)),
            "expected EmptyName, got {name_err:?}"
        );
    }

    #[test]
    fn rejects_name_with_colon() {
        assert_eq!(
            parse("huggingface:Q/N:weird").unwrap_err(),
            ParseError::NameContainsColon("huggingface:Q/N:weird".to_string())
        );
    }

    #[test]
    fn serde_roundtrips_via_struct() {
        // Serde uses the struct shape (scheme/org/name fields) so API
        // payloads are self-describing; the compact `scheme:org/name`
        // string is available through `Display`/`FromStr`.
        let original = ModelSourceId::new("helexa", "Helexa", "Qwen3.6-27B");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ModelSourceId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
}
|
||||
@@ -75,7 +75,11 @@ pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
|
||||
MessageContent::Text(t) => t,
|
||||
MessageContent::Parts(parts) => serde_json::to_string(&parts).unwrap_or_default(),
|
||||
};
|
||||
let stop = c.finish_reason.map(|r| map_stop_reason(&r));
|
||||
let stop = c.finish_reason.map(|r| match r.as_str() {
|
||||
"stop" => "end_turn".to_string(),
|
||||
"length" => "max_tokens".to_string(),
|
||||
other => other.to_string(),
|
||||
});
|
||||
(text, stop)
|
||||
}
|
||||
None => (String::new(), None),
|
||||
@@ -104,374 +108,3 @@ pub fn openai_to_anthropic(resp: ChatCompletionResponse) -> MessagesResponse {
|
||||
extra: Value::Null,
|
||||
}
|
||||
}
|
||||
|
||||
// ── Streaming SSE translation (#24) ──────────────────────────────────
|
||||
|
||||
/// Translates an OpenAI `finish_reason` into the Anthropic `stop_reason`.
///
/// `stop` → `end_turn`, `length` → `max_tokens`, `tool_calls` →
/// `tool_use`. Any other value is returned unchanged.
pub fn map_stop_reason(openai: &str) -> String {
    let anthropic = match openai {
        "stop" => "end_turn",
        "length" => "max_tokens",
        "tool_calls" => "tool_use",
        passthrough => passthrough,
    };
    anthropic.to_string()
}
|
||||
|
||||
/// Stateful OpenAI-SSE → Anthropic-SSE event translator.
|
||||
///
|
||||
/// Feed each parsed OpenAI [`crate::openai::ChatCompletionChunk`] to
|
||||
/// [`on_chunk`](Self::on_chunk) and call [`finish`](Self::finish) on
|
||||
/// `[DONE]` (or upstream EOF); both return ordered
|
||||
/// `(event_name, payload)` pairs ready to be framed as
|
||||
/// `event: <name>\ndata: <payload>\n\n`. The translation is stateless
|
||||
/// across requests — one instance per stream — and never buffers
|
||||
/// content: every text delta maps to a `content_block_delta`
|
||||
/// immediately.
|
||||
///
|
||||
/// Event sequence produced (per Anthropic's streaming spec):
|
||||
/// `message_start` → `content_block_start` / `content_block_delta`* /
|
||||
/// `content_block_stop` (text and `tool_use` blocks, indexed) →
|
||||
/// `message_delta` (stop_reason + output usage) → `message_stop`.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct AnthropicStreamTranslator {
|
||||
started: bool,
|
||||
finished: bool,
|
||||
/// Index of the currently-open content block, with its kind.
|
||||
open_block: Option<(u32, OpenBlock)>,
|
||||
next_index: u32,
|
||||
stop_reason: Option<String>,
|
||||
usage: Option<Usage>,
|
||||
/// Visible text deltas counted as an output-token estimate for
|
||||
/// streams whose upstream never sends a usage frame (neuron emits
|
||||
/// one chunk per token, so this is exact there).
|
||||
text_deltas: u64,
|
||||
}
|
||||
|
||||
/// Kind of the content block the translator currently has open.
#[derive(Debug, Eq, PartialEq)]
enum OpenBlock {
    /// A `text` block.
    Text,
    /// A `tool_use` block.
    ToolUse,
}
|
||||
|
||||
impl AnthropicStreamTranslator {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn on_chunk(&mut self, chunk: &crate::openai::ChatCompletionChunk) -> Vec<(String, Value)> {
|
||||
let mut out = Vec::new();
|
||||
if !self.started {
|
||||
self.started = true;
|
||||
out.push((
|
||||
"message_start".to_string(),
|
||||
json!({
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
// Upstream ids are opaque to Anthropic clients;
|
||||
// prefix for shape-compatibility with msg_* ids.
|
||||
"id": format!("msg_{}", chunk.id),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": chunk.model,
|
||||
"stop_reason": null,
|
||||
"stop_sequence": null,
|
||||
// Input tokens are unknown until (if ever) a
|
||||
// usage frame arrives; corrected in
|
||||
// message_delta. Anthropic clients sum deltas.
|
||||
"usage": { "input_tokens": 0, "output_tokens": 0 }
|
||||
}
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(usage) = &chunk.usage {
|
||||
self.usage = Some(usage.clone());
|
||||
}
|
||||
|
||||
for choice in &chunk.choices {
|
||||
if let Some(text) = choice.delta.get("content").and_then(Value::as_str)
|
||||
&& !text.is_empty()
|
||||
{
|
||||
self.ensure_text_block(&mut out);
|
||||
self.text_deltas += 1;
|
||||
let index = self.open_block.as_ref().map(|(i, _)| *i).unwrap_or(0);
|
||||
out.push((
|
||||
"content_block_delta".to_string(),
|
||||
json!({
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": { "type": "text_delta", "text": text }
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(calls) = choice.delta.get("tool_calls").and_then(Value::as_array) {
|
||||
for call in calls {
|
||||
let name = call
|
||||
.get("function")
|
||||
.and_then(|f| f.get("name"))
|
||||
.and_then(Value::as_str);
|
||||
let arguments = call
|
||||
.get("function")
|
||||
.and_then(|f| f.get("arguments"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default();
|
||||
if let Some(name) = name {
|
||||
// A named entry begins a new tool_use block.
|
||||
self.close_open_block(&mut out);
|
||||
let id = call
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("toolu_unknown");
|
||||
let index = self.next_index;
|
||||
self.next_index += 1;
|
||||
self.open_block = Some((index, OpenBlock::ToolUse));
|
||||
out.push((
|
||||
"content_block_start".to_string(),
|
||||
json!({
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": {
|
||||
"type": "tool_use",
|
||||
"id": id,
|
||||
"name": name,
|
||||
"input": {}
|
||||
}
|
||||
}),
|
||||
));
|
||||
}
|
||||
if !arguments.is_empty()
|
||||
&& let Some((index, OpenBlock::ToolUse)) = &self.open_block
|
||||
{
|
||||
out.push((
|
||||
"content_block_delta".to_string(),
|
||||
json!({
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": {
|
||||
"type": "input_json_delta",
|
||||
"partial_json": arguments
|
||||
}
|
||||
}),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(reason) = &choice.finish_reason {
|
||||
self.stop_reason = Some(map_stop_reason(reason));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Close the stream: emits the trailing block-stop, message_delta
|
||||
/// (stop_reason + output usage) and message_stop. Idempotent.
|
||||
pub fn finish(&mut self) -> Vec<(String, Value)> {
|
||||
let mut out = Vec::new();
|
||||
if self.finished || !self.started {
|
||||
self.finished = true;
|
||||
return out;
|
||||
}
|
||||
self.finished = true;
|
||||
self.close_open_block(&mut out);
|
||||
let output_tokens = self
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| u.completion_tokens)
|
||||
.unwrap_or(self.text_deltas);
|
||||
let mut usage = json!({ "output_tokens": output_tokens });
|
||||
if let Some(u) = &self.usage {
|
||||
usage["input_tokens"] = json!(u.prompt_tokens);
|
||||
}
|
||||
out.push((
|
||||
"message_delta".to_string(),
|
||||
json!({
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": self.stop_reason.as_deref().unwrap_or("end_turn"),
|
||||
"stop_sequence": null
|
||||
},
|
||||
"usage": usage
|
||||
}),
|
||||
));
|
||||
out.push((
|
||||
"message_stop".to_string(),
|
||||
json!({ "type": "message_stop" }),
|
||||
));
|
||||
out
|
||||
}
|
||||
|
||||
fn ensure_text_block(&mut self, out: &mut Vec<(String, Value)>) {
|
||||
match &self.open_block {
|
||||
Some((_, OpenBlock::Text)) => {}
|
||||
_ => {
|
||||
self.close_open_block(out);
|
||||
let index = self.next_index;
|
||||
self.next_index += 1;
|
||||
self.open_block = Some((index, OpenBlock::Text));
|
||||
out.push((
|
||||
"content_block_start".to_string(),
|
||||
json!({
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": { "type": "text", "text": "" }
|
||||
}),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn close_open_block(&mut self, out: &mut Vec<(String, Value)>) {
|
||||
if let Some((index, _)) = self.open_block.take() {
|
||||
out.push((
|
||||
"content_block_stop".to_string(),
|
||||
json!({ "type": "content_block_stop", "index": index }),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod stream_tests {
    use super::*;
    use crate::openai::{ChatCompletionChunk, ChunkChoice};

    /// Builds a single-choice chunk with the given delta and finish reason.
    fn chunk(delta: Value, finish: Option<&str>) -> ChatCompletionChunk {
        let choice = ChunkChoice {
            index: 0,
            delta,
            finish_reason: finish.map(str::to_owned),
            extra: Value::Null,
        };
        ChatCompletionChunk {
            id: "abc123".into(),
            object: "chat.completion.chunk".into(),
            created: 1,
            model: "Qwen/Qwen3-8B".into(),
            choices: vec![choice],
            usage: None,
            extra: Value::Null,
        }
    }

    /// Event names only, in emission order.
    fn names(events: &[(String, Value)]) -> Vec<&str> {
        events.iter().map(|(name, _)| name.as_str()).collect()
    }

    /// Feeds `chunks` through a fresh translator and appends `finish`.
    fn translate(chunks: Vec<ChatCompletionChunk>) -> Vec<(String, Value)> {
        let mut translator = AnthropicStreamTranslator::new();
        let mut events = Vec::new();
        for c in &chunks {
            events.extend(translator.on_chunk(c));
        }
        events.extend(translator.finish());
        events
    }

    #[test]
    fn text_stream_produces_full_anthropic_sequence() {
        let all = translate(vec![
            chunk(json!({"role": "assistant"}), None),
            chunk(json!({"content": "Hel"}), None),
            chunk(json!({"content": "lo"}), None),
            chunk(json!({}), Some("stop")),
        ]);

        let expected = vec![
            "message_start",
            "content_block_start",
            "content_block_delta",
            "content_block_delta",
            "content_block_stop",
            "message_delta",
            "message_stop",
        ];
        assert_eq!(names(&all), expected);

        // The model travels on message_start; the text on the deltas.
        assert_eq!(all[0].1["message"]["model"], "Qwen/Qwen3-8B");
        assert_eq!(all[2].1["delta"]["text"], "Hel");
        assert_eq!(all[3].1["delta"]["text"], "lo");

        // `stop` maps to end_turn. With no usage frame the output count
        // is the number of deltas (exact for neuron's one-chunk-per-token
        // streams).
        let message_delta = &all[5].1;
        assert_eq!(message_delta["delta"]["stop_reason"], "end_turn");
        assert_eq!(message_delta["usage"]["output_tokens"], 2);
    }

    #[test]
    fn length_maps_to_max_tokens_and_missing_finish_defaults_to_end_turn() {
        let mut truncated = AnthropicStreamTranslator::new();
        truncated.on_chunk(&chunk(json!({"content": "x"}), Some("length")));
        let closing = truncated.finish();
        assert_eq!(closing[1].1["delta"]["stop_reason"], "max_tokens");

        let mut unfinished = AnthropicStreamTranslator::new();
        unfinished.on_chunk(&chunk(json!({"content": "x"}), None));
        let closing = unfinished.finish();
        assert_eq!(closing[1].1["delta"]["stop_reason"], "end_turn");
    }

    #[test]
    fn tool_call_becomes_tool_use_block() {
        let tool_delta = json!({"tool_calls": [{
            "index": 0,
            "id": "call_7",
            "function": {"name": "get_weather", "arguments": "{\"city\":\"Brno\"}"}
        }]});
        let all = translate(vec![
            chunk(json!({"content": "Let me check."}), None),
            chunk(tool_delta, None),
            chunk(json!({}), Some("tool_calls")),
        ]);

        assert_eq!(
            names(&all),
            vec![
                "message_start",
                "content_block_start", // text
                "content_block_delta", // text delta
                "content_block_stop",  // text closed by tool block
                "content_block_start", // tool_use
                "content_block_delta", // input_json_delta
                "content_block_stop",
                "message_delta",
                "message_stop",
            ]
        );

        let tool_start = &all[4].1;
        assert_eq!(tool_start["content_block"]["type"], "tool_use");
        assert_eq!(tool_start["content_block"]["id"], "call_7");
        assert_eq!(tool_start["content_block"]["name"], "get_weather");
        assert_eq!(tool_start["index"], 1);
        assert_eq!(all[5].1["delta"]["partial_json"], "{\"city\":\"Brno\"}");
        assert_eq!(all[7].1["delta"]["stop_reason"], "tool_use");
    }

    #[test]
    fn usage_frame_feeds_message_delta() {
        let mut translator = AnthropicStreamTranslator::new();
        translator.on_chunk(&chunk(json!({"content": "hi"}), Some("stop")));

        // A usage-only frame: no choices, just the token counts.
        let mut usage_chunk = chunk(json!({}), None);
        usage_chunk.choices.clear();
        usage_chunk.usage = Some(crate::openai::Usage {
            prompt_tokens: 225,
            completion_tokens: 42,
            total_tokens: 267,
        });
        translator.on_chunk(&usage_chunk);

        let closing = translator.finish();
        let message_delta = &closing[1].1;
        assert_eq!(message_delta["usage"]["output_tokens"], 42);
        assert_eq!(message_delta["usage"]["input_tokens"], 225);
    }

    #[test]
    fn finish_is_idempotent_and_silent_without_start() {
        let mut never_started = AnthropicStreamTranslator::new();
        assert!(
            never_started.finish().is_empty(),
            "no events for an empty stream"
        );
        assert!(never_started.finish().is_empty());

        let mut started = AnthropicStreamTranslator::new();
        started.on_chunk(&chunk(json!({"content": "x"}), None));
        assert!(!started.finish().is_empty());
        assert!(
            started.finish().is_empty(),
            "second finish must emit nothing"
        );
    }
}
|
||||
|
||||
@@ -24,7 +24,6 @@ tokio-stream.workspace = true
|
||||
eventsource-stream.workspace = true
|
||||
bytes = "1"
|
||||
urlencoding = "2"
|
||||
url = "2"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
//! Streaming Anthropic SSE translation (#24).
|
||||
//!
|
||||
//! The `/v1/messages` handler translates the request envelope to
|
||||
//! OpenAI before proxying (see `cortex_core::translate`); this module
|
||||
//! completes the round trip for `stream: true` — the upstream OpenAI
|
||||
//! SSE stream is re-framed, event by event, into Anthropic's
|
||||
//! `message_start` / `content_block_*` / `message_delta` /
|
||||
//! `message_stop` sequence as it arrives. True streaming: each
|
||||
//! upstream chunk is translated and forwarded immediately; nothing is
|
||||
//! buffered beyond the current SSE event's bytes.
|
||||
//!
|
||||
//! The translation state machine itself is pure and lives in
|
||||
//! [`cortex_core::translate::AnthropicStreamTranslator`]; this module
|
||||
//! owns the wire concerns — splitting the upstream byte stream into
|
||||
//! SSE events, parsing `data:` payloads, and framing the translated
|
||||
//! events as `event: <name>\ndata: <json>\n\n`.
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::Response;
|
||||
use bytes::Bytes;
|
||||
use cortex_core::openai::ChatCompletionChunk;
|
||||
use cortex_core::translate::AnthropicStreamTranslator;
|
||||
use futures::StreamExt;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
/// Forward the translated OpenAI request to the upstream node and
|
||||
/// return the response translated to Anthropic SSE framing.
|
||||
pub async fn stream_translated(
|
||||
client: &reqwest::Client,
|
||||
endpoint: &str,
|
||||
openai_body: axum::body::Bytes,
|
||||
model_id: &str,
|
||||
node_name: &str,
|
||||
) -> Response {
|
||||
let url = format!("{endpoint}/v1/chat/completions");
|
||||
tracing::info!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %node_name,
|
||||
url = %url,
|
||||
"proxying streaming request (anthropic SSE translation)"
|
||||
);
|
||||
|
||||
let upstream = match client
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.body(openai_body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
node = %node_name,
|
||||
url = %url,
|
||||
error = %e,
|
||||
"anthropic stream: upstream request failed"
|
||||
);
|
||||
return anthropic_error(StatusCode::BAD_GATEWAY, "upstream request failed");
|
||||
}
|
||||
};
|
||||
|
||||
let status = upstream.status();
|
||||
if !status.is_success() {
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
node = %node_name,
|
||||
url = %url,
|
||||
status = status.as_u16(),
|
||||
"anthropic stream: upstream returned non-2xx"
|
||||
);
|
||||
return anthropic_error(
|
||||
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY),
|
||||
"upstream returned an error",
|
||||
);
|
||||
}
|
||||
|
||||
// Bounded channel: a slow client back-pressures the pump task,
|
||||
// which back-pressures the upstream read — same propagation
|
||||
// discipline as neuron's own projectors.
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::convert::Infallible>>(32);
|
||||
let node = node_name.to_string();
|
||||
tokio::spawn(async move {
|
||||
let mut upstream = upstream.bytes_stream();
|
||||
let mut translator = AnthropicStreamTranslator::new();
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut done = false;
|
||||
|
||||
'outer: while let Some(block) = upstream.next().await {
|
||||
let block = match block {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!(node = %node, error = %e, "anthropic stream: upstream read failed mid-stream");
|
||||
break;
|
||||
}
|
||||
};
|
||||
buf.extend_from_slice(&block);
|
||||
// SSE events are separated by a blank line.
|
||||
while let Some(pos) = find_event_boundary(&buf) {
|
||||
let event: Vec<u8> = buf.drain(..pos + 2).collect();
|
||||
let text = String::from_utf8_lossy(&event);
|
||||
for line in text.lines() {
|
||||
let Some(data) = line.strip_prefix("data:") else {
|
||||
continue;
|
||||
};
|
||||
let data = data.trim();
|
||||
if data == "[DONE]" {
|
||||
done = true;
|
||||
if !send_frames(&tx, translator.finish()).await {
|
||||
break 'outer;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(data) else {
|
||||
tracing::debug!(node = %node, "anthropic stream: unparsable upstream frame skipped");
|
||||
continue;
|
||||
};
|
||||
if !send_frames(&tx, translator.on_chunk(&chunk)).await {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Upstream ended without [DONE] (error or truncation): still
|
||||
// close the Anthropic event sequence so clients aren't left
|
||||
// with an unterminated message.
|
||||
if !done {
|
||||
let _ = send_frames(&tx, translator.finish()).await;
|
||||
}
|
||||
});
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("content-type", "text/event-stream")
|
||||
.header("cache-control", "no-cache")
|
||||
.body(Body::from_stream(ReceiverStream::new(rx)))
|
||||
.unwrap_or_else(|_| {
|
||||
anthropic_error(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"failed to build response",
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Locate the blank-line terminator of the first complete SSE event in `buf`.
///
/// Returns an index `p` such that `buf[..p + 2]` is the whole event
/// *including* its terminator, which is exactly what the caller's
/// `buf.drain(..pos + 2)` relies on. Returns `None` while no complete
/// event has been buffered yet.
///
/// Two terminators are recognised:
/// * `\n\n` starting at `i` yields `i`;
/// * `\r\n\r\n` starting at `i` yields `i + 2` (the index of the second
///   `\r\n`), so the caller still consumes the full four-byte terminator.
///
/// When both occur, the one that ends first wins so events are always
/// emitted in stream order.
fn find_event_boundary(buf: &[u8]) -> Option<usize> {
    let lf = buf.windows(2).position(|w| w == b"\n\n");
    // Shift by 2 so `p + 2` lands just past the CRLF CRLF terminator.
    let crlf = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|start| start + 2);

    match (lf, crlf) {
        (Some(a), Some(b)) => Some(a.min(b)),
        (a, b) => a.or(b),
    }
}
|
||||
|
||||
/// Render translated events as SSE frames and send them. Returns
|
||||
/// `false` when the client has gone away (receiver dropped).
|
||||
async fn send_frames(
|
||||
tx: &tokio::sync::mpsc::Sender<Result<Bytes, std::convert::Infallible>>,
|
||||
events: Vec<(String, serde_json::Value)>,
|
||||
) -> bool {
|
||||
for (name, payload) in events {
|
||||
let frame = format!("event: {name}\ndata: {payload}\n\n");
|
||||
if tx.send(Ok(Bytes::from(frame))).await.is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Anthropic-shaped error body (`{"type":"error","error":{...}}`).
|
||||
fn anthropic_error(status: StatusCode, message: &str) -> Response {
|
||||
let body = serde_json::json!({
|
||||
"type": "error",
|
||||
"error": { "type": "api_error", "message": message }
|
||||
});
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(body.to_string()))
|
||||
.expect("static error response must build")
|
||||
}
|
||||
@@ -20,7 +20,6 @@ pub fn api_routes() -> Router<Arc<CortexState>> {
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/completions", post(completions))
|
||||
.route("/v1/responses", post(responses))
|
||||
.route("/v1/models", get(list_models))
|
||||
.route("/v1/messages", post(anthropic_messages))
|
||||
.route("/health", get(health))
|
||||
@@ -35,94 +34,23 @@ async fn chat_completions(
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!(
|
||||
handler = "chat_completions",
|
||||
"rejected: missing 'model' field in request body"
|
||||
);
|
||||
return error_response(400, "missing 'model' field in request body");
|
||||
}
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "chat_completions",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
// RouteError's Display strings are short and informative
|
||||
// ("model 'X' not found...", "no healthy nodes available")
|
||||
// — fine to surface to the caller. The warn above carries
|
||||
// any extra context for operators.
|
||||
return error_response(e.http_status(), &e.to_string());
|
||||
}
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||
|
||||
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||
proxy_with_metrics(
|
||||
&fleet,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
body,
|
||||
&route.resolved_model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// `POST /v1/responses` — proxy to the appropriate backend node.
|
||||
///
|
||||
/// Same routing shape as [`chat_completions`]: extract `model` from
|
||||
/// the body, resolve to a node, forward verbatim. No translation —
|
||||
/// neuron speaks the Responses API natively (see
|
||||
/// `crates/neuron/src/wire/openai_responses.rs`), so the gateway is
|
||||
/// a pass-through. Streaming and non-streaming are handled
|
||||
/// identically; the upstream `Content-Type` (text/event-stream vs.
|
||||
/// application/json) propagates through the proxy.
|
||||
async fn responses(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!(
|
||||
handler = "responses",
|
||||
"rejected: missing 'model' field in request body"
|
||||
);
|
||||
return error_response(400, "missing 'model' field in request body");
|
||||
}
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "responses",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
return error_response(e.http_status(), &e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
|
||||
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||
proxy_with_metrics(
|
||||
&fleet,
|
||||
&route,
|
||||
"/v1/responses",
|
||||
headers,
|
||||
body,
|
||||
&route.resolved_model_id,
|
||||
&model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -135,63 +63,29 @@ async fn completions(
|
||||
) -> Response {
|
||||
let model_id = match extract_model(&body) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!(
|
||||
handler = "completions",
|
||||
"rejected: missing 'model' field in request body"
|
||||
);
|
||||
return error_response(400, "missing 'model' field in request body");
|
||||
}
|
||||
None => return error_response(400, "missing 'model' field in request body"),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "completions",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
// RouteError's Display strings are short and informative
|
||||
// ("model 'X' not found...", "no healthy nodes available")
|
||||
// — fine to surface to the caller. The warn above carries
|
||||
// any extra context for operators.
|
||||
return error_response(e.http_status(), &e.to_string());
|
||||
}
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||
|
||||
let body = rewrite_model_in_body(body, &route.resolved_model_id);
|
||||
proxy_with_metrics(
|
||||
&fleet,
|
||||
&route,
|
||||
"/v1/completions",
|
||||
headers,
|
||||
body,
|
||||
&route.resolved_model_id,
|
||||
)
|
||||
.await
|
||||
proxy_with_metrics(&fleet, &route, "/v1/completions", headers, body, &model_id).await
|
||||
}
|
||||
|
||||
/// `POST /v1/messages` — accept Anthropic format, translate, proxy, translate back.
|
||||
async fn anthropic_messages(
|
||||
State(fleet): State<Arc<CortexState>>,
|
||||
_headers: HeaderMap,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Response {
|
||||
// Parse as Anthropic request.
|
||||
let anth_req: cortex_core::anthropic::MessagesRequest = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
error = %e,
|
||||
"rejected: invalid Anthropic request body"
|
||||
);
|
||||
return error_response(400, "invalid Anthropic request body");
|
||||
}
|
||||
Err(e) => return error_response(400, &format!("invalid Anthropic request: {e}")),
|
||||
};
|
||||
|
||||
let model_id = anth_req.model.clone();
|
||||
@@ -201,43 +95,18 @@ async fn anthropic_messages(
|
||||
let openai_req = cortex_core::translate::anthropic_to_openai(anth_req);
|
||||
let openai_body = match serde_json::to_vec(&openai_req) {
|
||||
Ok(b) => Bytes::from(b),
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"internal: failed to serialise translated OpenAI request"
|
||||
);
|
||||
return error_response(500, "internal translation error");
|
||||
}
|
||||
Err(e) => return error_response(500, &format!("translation error: {e}")),
|
||||
};
|
||||
|
||||
let route = match router::resolve(&fleet, &model_id).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
error = %e,
|
||||
"route resolve failed"
|
||||
);
|
||||
// RouteError's Display strings are short and informative
|
||||
// ("model 'X' not found...", "no healthy nodes available")
|
||||
// — fine to surface to the caller. The warn above carries
|
||||
// any extra context for operators.
|
||||
return error_response(e.http_status(), &e.to_string());
|
||||
}
|
||||
Err(e) => return error_response(404, &e.to_string()),
|
||||
};
|
||||
|
||||
touch_model(&fleet, &route.node_name, &route.resolved_model_id).await;
|
||||
|
||||
// Swap the alias for the concrete id in the translated body so
|
||||
// neuron's harness sees a model name that matches what it has
|
||||
// loaded.
|
||||
let openai_body = rewrite_model_in_body(openai_body, &route.resolved_model_id);
|
||||
touch_model(&fleet, &route.node_name, &model_id).await;
|
||||
|
||||
let labels = [
|
||||
("model", route.resolved_model_id.clone()),
|
||||
("model", model_id.clone()),
|
||||
("node", route.node_name.clone()),
|
||||
];
|
||||
metrics::counter!("cortex_requests_total", &labels).increment(1);
|
||||
@@ -247,37 +116,31 @@ async fn anthropic_messages(
|
||||
let start = Instant::now();
|
||||
|
||||
if is_streaming {
|
||||
// Anthropic SSE translation (#24): upstream speaks OpenAI SSE;
|
||||
// re-frame it event-by-event into Anthropic's message_start /
|
||||
// content_block_* / message_delta / message_stop sequence.
|
||||
let resp = crate::anthropic_sse::stream_translated(
|
||||
// TODO: streaming Anthropic translation requires converting SSE format.
|
||||
// For now, proxy the OpenAI SSE stream directly (clients that can handle
|
||||
// OpenAI SSE will work; full Anthropic SSE translation is a follow-up).
|
||||
let result = proxy::forward_request(
|
||||
&fleet.http_client,
|
||||
&route.endpoint,
|
||||
&route,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
openai_body,
|
||||
&model_id,
|
||||
&route.node_name,
|
||||
)
|
||||
.await;
|
||||
metrics::histogram!("cortex_request_duration_seconds", &labels)
|
||||
.record(start.elapsed().as_secs_f64());
|
||||
if !resp.status().is_success() {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
match result {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
e.into_response()
|
||||
}
|
||||
}
|
||||
resp
|
||||
} else {
|
||||
// Non-streaming: proxy, buffer full response, translate back to Anthropic.
|
||||
let target_url = format!("{}/v1/chat/completions", route.endpoint);
|
||||
tracing::info!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
cold_start = route.cold_start,
|
||||
"proxying request"
|
||||
);
|
||||
let upstream_resp = fleet
|
||||
.http_client
|
||||
.post(&target_url)
|
||||
.post(format!("{}/v1/chat/completions", route.endpoint))
|
||||
.body(openai_body)
|
||||
.header("content-type", "application/json")
|
||||
.send()
|
||||
@@ -287,49 +150,22 @@ async fn anthropic_messages(
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
error = %e,
|
||||
"upstream request failed (network)"
|
||||
);
|
||||
return error_response(502, "upstream request failed");
|
||||
return error_response(502, &format!("upstream request failed: {e}"));
|
||||
}
|
||||
};
|
||||
|
||||
let upstream_status = upstream_resp.status();
|
||||
if !upstream_status.is_success() {
|
||||
if !upstream_resp.status().is_success() {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
let status = upstream_status.as_u16();
|
||||
let status = upstream_resp.status().as_u16();
|
||||
let body = upstream_resp.text().await.unwrap_or_default();
|
||||
let body_snippet = body.chars().take(512).collect::<String>();
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
status,
|
||||
body = %body_snippet,
|
||||
"upstream returned non-2xx"
|
||||
);
|
||||
return error_response(status, &format!("upstream returned {status}"));
|
||||
return error_response(status, &format!("upstream error: {body}"));
|
||||
}
|
||||
|
||||
let body_bytes = match upstream_resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
error = %e,
|
||||
"failed to read upstream response body"
|
||||
);
|
||||
return error_response(502, "failed to read upstream response");
|
||||
return error_response(502, &format!("failed to read upstream response: {e}"));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -338,20 +174,7 @@ async fn anthropic_messages(
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
let body_snippet = String::from_utf8_lossy(&body_bytes)
|
||||
.chars()
|
||||
.take(512)
|
||||
.collect::<String>();
|
||||
tracing::warn!(
|
||||
handler = "anthropic_messages",
|
||||
model = %model_id,
|
||||
node = %route.node_name,
|
||||
url = %target_url,
|
||||
error = %e,
|
||||
body = %body_snippet,
|
||||
"failed to parse upstream response as OpenAI ChatCompletionResponse"
|
||||
);
|
||||
return error_response(502, "malformed upstream response");
|
||||
return error_response(502, &format!("failed to parse upstream response: {e}"));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -362,65 +185,12 @@ async fn anthropic_messages(
|
||||
}
|
||||
}
|
||||
|
||||
/// `GET /v1/models` — union of (catalogue × topology feasibility) and
|
||||
/// (currently loaded somewhere). The result is what the fleet *could*
|
||||
/// serve, not just what's already loaded — so OpenAI-compatible tools
|
||||
/// see every model the operator has provisioned, and cortex
|
||||
/// transparently cold-loads the first time one is requested.
|
||||
/// `GET /v1/models` — aggregate models from all nodes.
|
||||
async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
use std::collections::HashMap;
|
||||
let now = Utc::now().timestamp() as u64;
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let catalogue = &fleet.catalogue;
|
||||
let mut model_map: std::collections::HashMap<String, CortexModelEntry> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
let mut entries: HashMap<String, CortexModelEntry> = HashMap::new();
|
||||
|
||||
// Pass 1: catalogue × topology. For every catalogue profile, find
|
||||
// healthy neurons whose discovered devices satisfy the profile.
|
||||
// Catalogue-defined models surface here even if nothing has loaded
|
||||
// them yet — that's the point of the unified endpoint.
|
||||
for profile in &catalogue.models {
|
||||
let mut feasible_on = Vec::new();
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
let Some(disc) = node.discovery.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
if profile.is_feasible_on(&node.name, &disc.devices) {
|
||||
feasible_on.push(node.name.clone());
|
||||
}
|
||||
}
|
||||
if feasible_on.is_empty() {
|
||||
// The catalogue lists this model but no neuron's topology
|
||||
// matches — surface it as not-loaded with no feasible
|
||||
// location. Hides nothing; lets operators see why a
|
||||
// configured model isn't reachable.
|
||||
feasible_on.clear();
|
||||
}
|
||||
entries.insert(
|
||||
profile.id.clone(),
|
||||
CortexModelEntry {
|
||||
id: profile.id.clone(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: false,
|
||||
feasible_on,
|
||||
locations: Vec::new(),
|
||||
// Catalogue profiles don't declare capabilities yet;
|
||||
// the union is filled in Pass 2 from loaded locations.
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Pass 2: layer the actually-loaded state on top. For each
|
||||
// (node, model) entry, attach a ModelLocation. If the model isn't
|
||||
// in the catalogue, create a new CortexModelEntry from scratch —
|
||||
// cortex doesn't refuse to surface a manually-loaded model just
|
||||
// because the operator didn't enumerate it in models.toml.
|
||||
for node in nodes.values() {
|
||||
for (model_id, entry) in &node.models {
|
||||
let location = ModelLocation {
|
||||
@@ -428,121 +198,19 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
||||
status: entry.status,
|
||||
vram_estimate_mb: entry.vram_estimate_mb,
|
||||
};
|
||||
let was_loaded = matches!(entry.status, cortex_core::node::ModelStatus::Loaded);
|
||||
entries
|
||||
model_map
|
||||
.entry(model_id.clone())
|
||||
.and_modify(|e| {
|
||||
e.locations.push(location.clone());
|
||||
if was_loaded {
|
||||
e.loaded = true;
|
||||
}
|
||||
// Union the per-node capabilities so a model loaded
|
||||
// on several neurons reports every modality any of
|
||||
// them advertises.
|
||||
for cap in &entry.capabilities {
|
||||
if !e.capabilities.contains(cap) {
|
||||
e.capabilities.push(cap.clone());
|
||||
}
|
||||
}
|
||||
})
|
||||
.and_modify(|e| e.locations.push(location.clone()))
|
||||
.or_insert_with(|| CortexModelEntry {
|
||||
id: model_id.clone(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: was_loaded,
|
||||
// Not in catalogue — cortex has no opinion on
|
||||
// feasibility; leave empty.
|
||||
feasible_on: Vec::new(),
|
||||
locations: vec![location],
|
||||
capabilities: entry.capabilities.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 3: surface pre-warming models. Each neuron's `/health`
|
||||
// activation snapshot (polled separately from /models) reports
|
||||
// `in_progress` (the model currently materialising) and `pending`
|
||||
// (queued behind it). Neither appears on the neuron's `/models`
|
||||
// yet — that endpoint only knows about fully-loaded handles — so
|
||||
// without this pass a client polling `/v1/models` during pre-warm
|
||||
// sees Qwen3.6-27B with no location and concludes "not there".
|
||||
// Synthesising a Loading location instead tells clients the model
|
||||
// is on its way. Idempotent against Pass 2: if a Loading location
|
||||
// for this node already exists (shouldn't, but be safe) we skip.
|
||||
for node in nodes.values() {
|
||||
let Some(activation) = node.activation.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
let mut loading_ids: Vec<&str> = Vec::new();
|
||||
if let Some(id) = activation.in_progress.as_deref() {
|
||||
loading_ids.push(id);
|
||||
}
|
||||
for id in &activation.pending {
|
||||
loading_ids.push(id.as_str());
|
||||
}
|
||||
for model_id in loading_ids {
|
||||
let location = ModelLocation {
|
||||
node: node.name.clone(),
|
||||
status: cortex_core::node::ModelStatus::Loading,
|
||||
vram_estimate_mb: None,
|
||||
};
|
||||
entries
|
||||
.entry(model_id.to_string())
|
||||
.and_modify(|e| {
|
||||
let already = e.locations.iter().any(|l| {
|
||||
l.node == node.name && l.status == cortex_core::node::ModelStatus::Loading
|
||||
});
|
||||
if !already {
|
||||
e.locations.push(location.clone());
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| CortexModelEntry {
|
||||
id: model_id.to_string(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: false,
|
||||
feasible_on: Vec::new(),
|
||||
locations: vec![location],
|
||||
// A model that's only mid-prewarm has no loaded
|
||||
// location to read capabilities from yet.
|
||||
capabilities: Vec::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
let data: Vec<Value> = model_map.values().map(|e| json!(e)).collect();
|
||||
|
||||
// Pass 4: surface aliases as their own entries pointing at the
|
||||
// same locations as the target id, so a client browsing /v1/models
|
||||
// sees "helexa/small" / "helexa/balanced" / "helexa/large" (or
|
||||
// whatever the operator defined) and can request inference
|
||||
// against them directly. Aliases that point at unknown targets
|
||||
// are skipped — surfacing a dead alias would be misleading.
|
||||
for (alias, target) in &catalogue.aliases {
|
||||
let Some(target_entry) = entries.get(target).cloned() else {
|
||||
tracing::warn!(
|
||||
alias = alias,
|
||||
target = target,
|
||||
"alias points at a model not present in catalogue or fleet; skipping"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
entries.insert(
|
||||
alias.clone(),
|
||||
CortexModelEntry {
|
||||
id: alias.clone(),
|
||||
object: "model".into(),
|
||||
created: now,
|
||||
owned_by: "helexa".into(),
|
||||
loaded: target_entry.loaded,
|
||||
feasible_on: target_entry.feasible_on,
|
||||
locations: target_entry.locations,
|
||||
capabilities: target_entry.capabilities,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let data: Vec<Value> = entries.values().map(|e| json!(e)).collect();
|
||||
Json(json!({
|
||||
"object": "list",
|
||||
"data": data,
|
||||
@@ -586,8 +254,7 @@ async fn proxy_with_metrics(
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
let result =
|
||||
proxy::forward_request(&fleet.http_client, route, path, headers, body, model_id).await;
|
||||
let result = proxy::forward_request(&fleet.http_client, route, path, headers, body).await;
|
||||
let duration = start.elapsed();
|
||||
|
||||
match result {
|
||||
@@ -598,9 +265,6 @@ async fn proxy_with_metrics(
|
||||
}
|
||||
Err(e) => {
|
||||
metrics::counter!("cortex_request_errors_total", &labels).increment(1);
|
||||
// proxy::forward_request already warn'd with wire-level
|
||||
// detail (target URL, error, status). ProxyError::into_response
|
||||
// now returns a generic message — no body leak.
|
||||
e.into_response()
|
||||
}
|
||||
}
|
||||
@@ -621,38 +285,6 @@ fn extract_model(body: &[u8]) -> Option<String> {
|
||||
v.get("model")?.as_str().map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Rewrite the `model` field of an OpenAI-style JSON request body to
|
||||
/// the resolved concrete id. Returns the original bytes if `new_model`
|
||||
/// matches what's already there or the body fails to parse — the
|
||||
/// caller has already extracted `model` via `extract_model`, so a
|
||||
/// parse failure here would only happen on a body the client crafted
|
||||
/// to defeat us, and we'd rather proxy it unchanged than 500.
|
||||
///
|
||||
/// Needed because neuron rejects requests whose `model` field doesn't
|
||||
/// match a loaded model, so a client that sends `model: "helexa/small"`
|
||||
/// would hit a 404 at the harness unless we swap it for the concrete
|
||||
/// id the alias resolved to.
|
||||
fn rewrite_model_in_body(body: Bytes, new_model: &str) -> Bytes {
|
||||
let Ok(mut v) = serde_json::from_slice::<Value>(&body) else {
|
||||
return body;
|
||||
};
|
||||
let needs_rewrite = v
|
||||
.get("model")
|
||||
.and_then(|m| m.as_str())
|
||||
.map(|m| m != new_model)
|
||||
.unwrap_or(false);
|
||||
if !needs_rewrite {
|
||||
return body;
|
||||
}
|
||||
if let Value::Object(obj) = &mut v {
|
||||
obj.insert("model".into(), Value::String(new_model.to_string()));
|
||||
}
|
||||
match serde_json::to_vec(&v) {
|
||||
Ok(bytes) => Bytes::from(bytes),
|
||||
Err(_) => body,
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: u16, message: &str) -> Response {
|
||||
let code = axum::http::StatusCode::from_u16(status)
|
||||
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod anthropic_sse;
|
||||
pub mod evictor;
|
||||
pub mod handlers;
|
||||
pub mod metrics;
|
||||
|
||||
@@ -46,14 +46,6 @@ fn describe_metrics() {
|
||||
"Generation throughput in tokens per second"
|
||||
);
|
||||
metrics::describe_counter!("cortex_requests_total", "Total number of proxied requests");
|
||||
metrics::describe_counter!(
|
||||
"cortex_prompt_tokens_total",
|
||||
"Total prompt tokens reported by upstream usage objects"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_completion_tokens_total",
|
||||
"Total completion tokens reported by upstream usage objects"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"cortex_request_errors_total",
|
||||
"Total number of failed proxy requests"
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
use crate::state::CortexState;
|
||||
use chrono::Utc;
|
||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||
use cortex_core::harness::ModelInfo;
|
||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||
use std::sync::Arc;
|
||||
@@ -26,59 +25,7 @@ pub async fn poll_once(fleet: &CortexState) {
|
||||
}
|
||||
}
|
||||
|
||||
/// One-shot fetch of `GET /discovery`. Cached on the NodeState forever
|
||||
/// after the first success — topology is invariant for a given neuron
|
||||
/// process. Skipped when the cache is already populated.
|
||||
async fn maybe_poll_discovery(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
{
|
||||
let nodes = fleet.nodes.read().await;
|
||||
match nodes.get(name) {
|
||||
Some(n) if n.discovery.is_some() => return,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let url = format!("{endpoint}/discovery");
|
||||
let resp = match fleet
|
||||
.http_client
|
||||
.get(&url)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => r,
|
||||
Ok(r) => {
|
||||
tracing::debug!(node = name, status = %r.status(), "discovery probe non-success");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(node = name, error = %e, "discovery probe unreachable");
|
||||
return;
|
||||
}
|
||||
};
|
||||
match resp.json::<DiscoveryResponse>().await {
|
||||
Ok(d) => {
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(name) {
|
||||
tracing::info!(
|
||||
node = name,
|
||||
hostname = %d.hostname,
|
||||
devices = d.devices.len(),
|
||||
"discovery cached"
|
||||
);
|
||||
node.discovery = Some(d);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(node = name, error = %e, "failed to parse /discovery response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
// Topology first — cheap once cached, and the router needs it to
|
||||
// route requests against catalogue entries that aren't loaded yet.
|
||||
maybe_poll_discovery(fleet, name, endpoint).await;
|
||||
|
||||
let url = format!("{endpoint}/models");
|
||||
|
||||
let result = fleet
|
||||
@@ -107,14 +54,12 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
.and_modify(|e| {
|
||||
e.status = status;
|
||||
e.vram_estimate_mb = upstream.vram_used_mb;
|
||||
e.capabilities = upstream.capabilities.clone();
|
||||
})
|
||||
.or_insert_with(|| ModelEntry {
|
||||
id: upstream.id.clone(),
|
||||
status,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: upstream.vram_used_mb,
|
||||
capabilities: upstream.capabilities.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -144,51 +89,6 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
node.healthy = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Release the write lock before the next HTTP call.
|
||||
drop(nodes);
|
||||
|
||||
// Poll /health for the activation snapshot. We don't want this to
|
||||
// flip the node to unhealthy on its own — a neuron that's serving
|
||||
// /models fine is still operational even if /health is briefly
|
||||
// unavailable — so failures are debug-level and leave the existing
|
||||
// activation reading in place.
|
||||
poll_health(fleet, name, endpoint).await;
|
||||
}
|
||||
|
||||
/// Fetch `/health` and stash the activation snapshot on NodeState.
|
||||
/// Decoupled from the /models poll so a /health glitch doesn't mark
|
||||
/// the neuron unhealthy or evict the model list.
|
||||
async fn poll_health(fleet: &CortexState, name: &str, endpoint: &str) {
|
||||
let url = format!("{endpoint}/health");
|
||||
let resp = match fleet
|
||||
.http_client
|
||||
.get(&url)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => r,
|
||||
Ok(r) => {
|
||||
tracing::debug!(node = name, status = %r.status(), "/health probe non-success");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(node = name, error = %e, "/health probe failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
match resp.json::<HealthResponse>().await {
|
||||
Ok(h) => {
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(name) {
|
||||
node.activation = Some(h.activation);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(node = name, error = %e, "failed to parse /health response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_status(s: &str) -> ModelStatus {
|
||||
@@ -196,8 +96,6 @@ fn parse_status(s: &str) -> ModelStatus {
|
||||
"loaded" => ModelStatus::Loaded,
|
||||
"unloaded" => ModelStatus::Unloaded,
|
||||
"reloading" => ModelStatus::Reloading,
|
||||
"loading" => ModelStatus::Loading,
|
||||
"recovering" => ModelStatus::Recovering,
|
||||
_ => ModelStatus::Loaded,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//! Streaming HTTP reverse proxy to neuron backends.
|
||||
//! Streaming HTTP reverse proxy to mistral.rs backends.
|
||||
//!
|
||||
//! For streaming requests, SSE chunks are forwarded as they arrive.
|
||||
//! The proxy captures timing information for metrics but does not
|
||||
@@ -9,30 +9,16 @@ use anyhow::Result;
|
||||
use axum::body::Body;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use futures::Stream;
|
||||
use futures::stream::BoxStream;
|
||||
use reqwest::Client;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Proxy a request body to the resolved backend node and stream the response.
|
||||
///
|
||||
/// Logging contract: every call emits exactly one structured event at
|
||||
/// info / warn level for operator visibility, regardless of outcome.
|
||||
/// Network-level failures and non-2xx upstream statuses are warn'd here
|
||||
/// (closest to the wire); the user-facing response carries only the
|
||||
/// status code and a generic message — implementation detail (body,
|
||||
/// error chain) lives in the log, never in the API surface.
|
||||
pub async fn forward_request(
|
||||
client: &Client,
|
||||
route: &RouteDecision,
|
||||
path: &str,
|
||||
headers: HeaderMap,
|
||||
body: bytes::Bytes,
|
||||
model_id: &str,
|
||||
) -> Result<Response, ProxyError> {
|
||||
let request_start = Instant::now();
|
||||
let url = format!("{}{}", route.endpoint, path);
|
||||
tracing::info!(
|
||||
node = %route.node_name,
|
||||
@@ -51,39 +37,13 @@ pub async fn forward_request(
|
||||
req_builder = req_builder.header(key, value);
|
||||
}
|
||||
|
||||
let upstream_resp = match req_builder.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
error = %e,
|
||||
"proxy: upstream request failed (network)"
|
||||
);
|
||||
return Err(ProxyError::Upstream(e));
|
||||
}
|
||||
};
|
||||
let upstream_resp = req_builder.send().await.map_err(ProxyError::Upstream)?;
|
||||
|
||||
let upstream_status = upstream_resp.status();
|
||||
if !upstream_status.is_success() {
|
||||
// Streaming body — can't snippet without breaking the stream
|
||||
// pass-through. Log status + URL; the client still gets the
|
||||
// upstream status, just without the leaked body.
|
||||
tracing::warn!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
status = upstream_status.as_u16(),
|
||||
"proxy: upstream returned non-2xx"
|
||||
);
|
||||
}
|
||||
|
||||
let status = StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let status =
|
||||
StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
let resp_headers = upstream_resp.headers().clone();
|
||||
let stream = TokenMetricsStream::new(
|
||||
Box::pin(upstream_resp.bytes_stream()),
|
||||
TokenMetrics::new(model_id, &route.node_name, request_start),
|
||||
);
|
||||
let stream = upstream_resp.bytes_stream();
|
||||
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
@@ -92,261 +52,31 @@ pub async fn forward_request(
|
||||
response = response.header(key, value);
|
||||
}
|
||||
|
||||
response.body(body).map_err(|e| {
|
||||
tracing::warn!(
|
||||
node = %route.node_name,
|
||||
url = %url,
|
||||
error = %e,
|
||||
"proxy: failed to build response"
|
||||
);
|
||||
ProxyError::ResponseBuild(e.to_string())
|
||||
})
|
||||
response
|
||||
.body(body)
|
||||
.map_err(|e| ProxyError::ResponseBuild(e.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProxyError {
|
||||
#[error("upstream request failed")]
|
||||
#[error("upstream request failed: {0}")]
|
||||
Upstream(reqwest::Error),
|
||||
#[error("failed to build response")]
|
||||
#[error("failed to build response: {0}")]
|
||||
ResponseBuild(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for ProxyError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, message) = match &self {
|
||||
ProxyError::Upstream(_) => (StatusCode::BAD_GATEWAY, "upstream request failed"),
|
||||
ProxyError::ResponseBuild(_) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"failed to build response",
|
||||
),
|
||||
let status = match &self {
|
||||
ProxyError::Upstream(_) => StatusCode::BAD_GATEWAY,
|
||||
ProxyError::ResponseBuild(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"message": message,
|
||||
"message": self.to_string(),
|
||||
"type": "proxy_error",
|
||||
}
|
||||
});
|
||||
(status, axum::Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Per-request token metrics (#21) ─────────────────────────────────
|
||||
//
|
||||
// The proxy never buffers or re-serialises the upstream body — chunks
|
||||
// are forwarded verbatim. For metrics it observes each chunk's arrival
|
||||
// time and keeps a bounded tail of the body text, from which the final
|
||||
// OpenAI `usage` object (present on the last SSE chunk and on
|
||||
// non-streaming JSON bodies alike) yields engine-truth token counts.
|
||||
//
|
||||
// Emitted per request, labelled {model, node}:
|
||||
// cortex_time_to_first_token_seconds (histogram) — first body chunk
|
||||
// cortex_tokens_per_second (histogram) — completion tokens
|
||||
// over the decode window (first→last chunk); falls back to the
|
||||
// full request duration for single-chunk (non-streaming) bodies
|
||||
// cortex_prompt_tokens_total / cortex_completion_tokens_total (counters)
|
||||
|
||||
/// Cap on the retained body tail. The usage object rides on the final
|
||||
/// chunk, so a generous tail is plenty; the cap bounds memory on huge
|
||||
/// non-streaming bodies.
|
||||
const TAIL_CAP_BYTES: usize = 64 * 1024;
|
||||
|
||||
/// Find the value of the LAST `"key": <integer>` occurrence in `tail`.
///
/// Pure and chunk-boundary-safe (the tail is contiguous appended text).
/// The needle includes both quotes, so `completion_tokens` never matches
/// `completion_tokens_details`.
///
/// Occurrences whose value is not a plain unsigned integer are skipped
/// rather than guessed at: `null`, strings, negatives, values that overflow
/// `u64`, and non-integer JSON numbers (`1.5`, `1e3`), which must not be
/// truncated to their leading digits. When the final occurrence is skipped,
/// the most recent earlier integer occurrence (if any) is returned.
///
/// Returns `None` when no occurrence of `key` carries an integer value.
pub(crate) fn last_count_for(tail: &str, key: &str) -> Option<u64> {
    let needle = format!("\"{key}\"");
    tail.match_indices(needle.as_str())
        .filter_map(|(start, matched)| {
            // Text after the quoted key must be `:` (optionally padded with
            // whitespace) followed by the value.
            let value = tail[start + matched.len()..]
                .trim_start()
                .strip_prefix(':')?
                .trim_start();
            let digits_end = value
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(value.len());
            let (digits, rest) = value.split_at(digits_end);
            // A fraction or exponent marker means this is not an integer
            // count; skipping is safer than recording a truncated value.
            if rest.starts_with(|c: char| matches!(c, '.' | 'e' | 'E')) {
                return None;
            }
            // An empty digit run (e.g. `null`) and u64 overflow both fail
            // to parse and are skipped.
            digits.parse::<u64>().ok()
        })
        .last()
}
|
||||
|
||||
struct TokenMetrics {
|
||||
labels: [(&'static str, String); 2],
|
||||
request_start: Instant,
|
||||
first_chunk: Option<Instant>,
|
||||
last_chunk: Option<Instant>,
|
||||
tail: String,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
impl TokenMetrics {
|
||||
fn new(model_id: &str, node_name: &str, request_start: Instant) -> Self {
|
||||
Self {
|
||||
labels: [
|
||||
("model", model_id.to_string()),
|
||||
("node", node_name.to_string()),
|
||||
],
|
||||
request_start,
|
||||
first_chunk: None,
|
||||
last_chunk: None,
|
||||
tail: String::new(),
|
||||
finished: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn observe(&mut self, chunk: &[u8]) {
|
||||
let now = Instant::now();
|
||||
self.first_chunk.get_or_insert(now);
|
||||
self.last_chunk = Some(now);
|
||||
self.tail.push_str(&String::from_utf8_lossy(chunk));
|
||||
if self.tail.len() > TAIL_CAP_BYTES {
|
||||
// Keep the newest half; the usage object is always at the
|
||||
// very end of the body. Split at a char boundary.
|
||||
let mut cut = self.tail.len() - TAIL_CAP_BYTES / 2;
|
||||
while !self.tail.is_char_boundary(cut) {
|
||||
cut += 1;
|
||||
}
|
||||
self.tail.drain(..cut);
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit the metrics exactly once — called on clean stream end and
|
||||
/// from Drop (client disconnect mid-stream still records what we
|
||||
/// saw).
|
||||
fn finish(&mut self) {
|
||||
if self.finished {
|
||||
return;
|
||||
}
|
||||
self.finished = true;
|
||||
let Some(first) = self.first_chunk else {
|
||||
return; // no body ever arrived — nothing to record
|
||||
};
|
||||
let ttft = first.duration_since(self.request_start).as_secs_f64();
|
||||
metrics::histogram!("cortex_time_to_first_token_seconds", &self.labels).record(ttft);
|
||||
|
||||
if let Some(prompt) = last_count_for(&self.tail, "prompt_tokens") {
|
||||
metrics::counter!("cortex_prompt_tokens_total", &self.labels).increment(prompt);
|
||||
}
|
||||
let Some(completion) = last_count_for(&self.tail, "completion_tokens") else {
|
||||
return;
|
||||
};
|
||||
if completion == 0 {
|
||||
return;
|
||||
}
|
||||
metrics::counter!("cortex_completion_tokens_total", &self.labels).increment(completion);
|
||||
|
||||
let last = self.last_chunk.unwrap_or(first);
|
||||
let decode_window = last.duration_since(first).as_secs_f64();
|
||||
// Streaming: rate over the decode window (first→last chunk).
|
||||
// Non-streaming bodies arrive as ~one chunk (window ≈ 0), where
|
||||
// the only honest denominator is the full request duration.
|
||||
let secs = if decode_window >= 0.1 {
|
||||
decode_window
|
||||
} else {
|
||||
last.duration_since(self.request_start).as_secs_f64()
|
||||
};
|
||||
if secs > 0.0 {
|
||||
metrics::histogram!("cortex_tokens_per_second", &self.labels)
|
||||
.record(completion as f64 / secs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pass-through stream wrapper that feeds [`TokenMetrics`]. Emits on
|
||||
/// clean end-of-stream; the Drop impl covers client disconnects.
|
||||
struct TokenMetricsStream {
|
||||
inner: BoxStream<'static, Result<bytes::Bytes, reqwest::Error>>,
|
||||
metrics: TokenMetrics,
|
||||
}
|
||||
|
||||
impl TokenMetricsStream {
|
||||
fn new(
|
||||
inner: BoxStream<'static, Result<bytes::Bytes, reqwest::Error>>,
|
||||
metrics: TokenMetrics,
|
||||
) -> Self {
|
||||
Self { inner, metrics }
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for TokenMetricsStream {
|
||||
type Item = Result<bytes::Bytes, reqwest::Error>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
match this.inner.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(chunk))) => {
|
||||
this.metrics.observe(&chunk);
|
||||
Poll::Ready(Some(Ok(chunk)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(None) => {
|
||||
this.metrics.finish();
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TokenMetricsStream {
|
||||
fn drop(&mut self) {
|
||||
self.metrics.finish();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::last_count_for;
|
||||
|
||||
#[test]
|
||||
fn extracts_counts_from_final_sse_usage_chunk() {
|
||||
let tail = concat!(
|
||||
"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
|
||||
"data: {\"choices\":[],\"usage\":{\"prompt_tokens\":225,",
|
||||
"\"completion_tokens\":42,\"total_tokens\":267}}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
);
|
||||
assert_eq!(last_count_for(tail, "prompt_tokens"), Some(225));
|
||||
assert_eq!(last_count_for(tail, "completion_tokens"), Some(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_counts_from_non_streaming_body() {
|
||||
let tail = "{\"choices\":[{\"message\":{\"content\":\"hi\"}}],\
|
||||
\"usage\":{\"prompt_tokens\": 12, \"completion_tokens\": 7}}";
|
||||
assert_eq!(last_count_for(tail, "prompt_tokens"), Some(12));
|
||||
assert_eq!(last_count_for(tail, "completion_tokens"), Some(7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_details_variants_and_takes_last_occurrence() {
|
||||
// completion_tokens_details must not shadow completion_tokens,
|
||||
// and the LAST usage object wins (matters when content echoes
|
||||
// a usage-shaped string earlier in the stream).
|
||||
let tail = concat!(
|
||||
"data: {\"usage\":{\"completion_tokens\":1}}\n\n",
|
||||
"data: {\"usage\":{\"completion_tokens\":99,",
|
||||
"\"completion_tokens_details\":{\"reasoning_tokens\":3}}}\n\n"
|
||||
);
|
||||
assert_eq!(last_count_for(tail, "completion_tokens"), Some(99));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absent_keys_yield_none() {
|
||||
assert_eq!(
|
||||
last_count_for("data: [DONE]\n\n", "completion_tokens"),
|
||||
None
|
||||
);
|
||||
assert_eq!(last_count_for("", "prompt_tokens"), None);
|
||||
// key present but non-numeric value
|
||||
assert_eq!(
|
||||
last_count_for("\"completion_tokens\": null", "completion_tokens"),
|
||||
None
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,21 +2,13 @@
|
||||
//!
|
||||
//! Given a model ID from an inbound request, determine which node should
|
||||
//! handle it. Priority:
|
||||
//! 1. Node where the model is currently `Loaded` → use it.
|
||||
//! 2. Node where the model is `Unloaded` → use it; neuron's existing
|
||||
//! lazy-load behaviour will reload before serving the request.
|
||||
//! 3. Model is in the catalogue → pick a feasible neuron, call
|
||||
//! `POST /models/load`, wait for the load to complete, then
|
||||
//! proxy. First-request cold-load latency is acceptable per the
|
||||
//! unified-endpoint contract.
|
||||
//! 4. Not in catalogue, not loaded anywhere → 404.
|
||||
//! 1. Node where the model is currently `Loaded`
|
||||
//! 2. Node where the model is `Unloaded` (will lazy-load on request)
|
||||
//! 3. Error: model not found on any node
|
||||
|
||||
use crate::state::CortexState;
|
||||
use cortex_core::catalogue::ModelProfile;
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use cortex_core::node::ModelStatus;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// The routing decision: which node endpoint to proxy the request to.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -24,345 +16,62 @@ pub struct RouteDecision {
|
||||
pub node_name: String,
|
||||
/// The inference endpoint to proxy to (from neuron's /models/{id}/endpoint).
|
||||
pub endpoint: String,
|
||||
/// Whether the model will need to load (cold start). Set to true
|
||||
/// when we proxied to an `Unloaded` node (lazy load on neuron) or
|
||||
/// when we just triggered an explicit cold-load via the catalogue
|
||||
/// path.
|
||||
/// Whether the model will need to load (cold start).
|
||||
pub cold_start: bool,
|
||||
/// The concrete model id we actually routed to. Equal to the
|
||||
/// caller's requested id unless an alias was resolved (e.g. caller
|
||||
/// asked for `helexa/small`, this carries `Qwen/Qwen3-1.7B`). The
|
||||
/// handler uses this to rewrite the request body's `model` field
|
||||
/// before proxying — neurons reject requests where the body's
|
||||
/// model name doesn't match a loaded model.
|
||||
pub resolved_model_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RouteError {
|
||||
#[error("model '{0}' not found on any node and not in catalogue")]
|
||||
#[error("model '{0}' not found on any node")]
|
||||
ModelNotFound(String),
|
||||
#[error("no healthy nodes available")]
|
||||
NoHealthyNodes,
|
||||
#[error("failed to resolve inference endpoint for model '{0}' on node '{1}'")]
|
||||
EndpointResolveFailed(String, String),
|
||||
#[error(
|
||||
"model '{model_id}' is in the catalogue but no healthy neuron's topology satisfies its constraints"
|
||||
)]
|
||||
NoFeasibleNeuron { model_id: String },
|
||||
#[error("cold-load of '{model_id}' on '{node}' failed: {message}")]
|
||||
ColdLoadFailed {
|
||||
model_id: String,
|
||||
node: String,
|
||||
message: String,
|
||||
},
|
||||
#[error(
|
||||
"model '{model_id}' is recovering on node '{node}' (device context rebuild in progress) — retry shortly"
|
||||
)]
|
||||
ModelRecovering { model_id: String, node: String },
|
||||
}
|
||||
|
||||
impl RouteError {
|
||||
/// HTTP status the gateway should answer with. `ModelRecovering`
|
||||
/// is the one transient case (503, retry the same request);
|
||||
/// everything else keeps the long-standing 404 behaviour.
|
||||
pub fn http_status(&self) -> u16 {
|
||||
match self {
|
||||
RouteError::ModelRecovering { .. } => 503,
|
||||
_ => 404,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve which node should serve a request for the given model.
|
||||
/// Asks the neuron for the inference endpoint after selecting a node.
|
||||
pub async fn resolve(
|
||||
fleet: &Arc<CortexState>,
|
||||
requested_model_id: &str,
|
||||
model_id: &str,
|
||||
) -> Result<RouteDecision, RouteError> {
|
||||
// Alias resolution first — swap `helexa/small` (etc.) for the
|
||||
// concrete id before any node lookups so the rest of routing,
|
||||
// loading, and metrics deal in concrete ids only. `resolve_alias`
|
||||
// returns the input verbatim when it isn't an alias.
|
||||
let model_id = fleet.catalogue.resolve_alias(requested_model_id);
|
||||
if model_id != requested_model_id {
|
||||
tracing::debug!(
|
||||
requested = requested_model_id,
|
||||
resolved = model_id,
|
||||
"alias resolved"
|
||||
);
|
||||
}
|
||||
// Snapshot loaded / unloaded / recovering state from the poller cache.
|
||||
let (loaded_route, unloaded_route, recovering_node, any_healthy) = {
|
||||
let (node_name, neuron_endpoint, cold_start) = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut loaded_route = None;
|
||||
let mut unloaded_route = None;
|
||||
let mut recovering_node = None;
|
||||
let mut any_healthy = false;
|
||||
|
||||
let mut loaded_candidate = None;
|
||||
let mut unloaded_candidate = None;
|
||||
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
any_healthy = true;
|
||||
if let Some(entry) = node.models.get(model_id) {
|
||||
match entry.status {
|
||||
ModelStatus::Loaded | ModelStatus::Reloading => {
|
||||
loaded_route = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||
loaded_candidate = Some((node.name.clone(), node.endpoint.clone(), false));
|
||||
break;
|
||||
}
|
||||
ModelStatus::Unloaded => {
|
||||
if unloaded_route.is_none() {
|
||||
unloaded_route = Some((node.name.clone(), node.endpoint.clone(), true));
|
||||
if unloaded_candidate.is_none() {
|
||||
unloaded_candidate =
|
||||
Some((node.name.clone(), node.endpoint.clone(), true));
|
||||
}
|
||||
}
|
||||
// Auto-recovering (#17/#20): the model is rebuilding
|
||||
// its device context on this node. Hold the route —
|
||||
// answer "retry shortly" rather than 404, and do NOT
|
||||
// fall through to the catalogue cold-load, which
|
||||
// would race a second placement (and a second copy's
|
||||
// worth of VRAM) against the in-flight recovery.
|
||||
ModelStatus::Recovering => {
|
||||
if recovering_node.is_none() {
|
||||
recovering_node = Some(node.name.clone());
|
||||
}
|
||||
}
|
||||
// Loading is gateway-synthesised from neuron's
|
||||
// activation snapshot; it never appears on the
|
||||
// wire from neuron's `/models`. Skip — the model
|
||||
// isn't actually servable yet. The pre-existing
|
||||
// race (catalogue cold_load fires a parallel
|
||||
// /models/load against the in-flight load) is no
|
||||
// worse than before; fixing it needs neuron-side
|
||||
// in-flight tracking on /models/load itself.
|
||||
ModelStatus::Loading => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
(loaded_route, unloaded_route, recovering_node, any_healthy)
|
||||
};
|
||||
|
||||
if !any_healthy {
|
||||
return Err(RouteError::NoHealthyNodes);
|
||||
}
|
||||
|
||||
// Priority 1: already loaded.
|
||||
if let Some((node_name, neuron_endpoint, cold_start)) = loaded_route {
|
||||
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||
}
|
||||
|
||||
// Priority 2: recovering somewhere — transient hold, not a reroute.
|
||||
if let Some(node) = recovering_node {
|
||||
return Err(RouteError::ModelRecovering {
|
||||
model_id: model_id.to_string(),
|
||||
node,
|
||||
});
|
||||
}
|
||||
|
||||
// Priority 3: known to neuron but unloaded (neuron's lazy load).
|
||||
if let Some((node_name, neuron_endpoint, cold_start)) = unloaded_route {
|
||||
return finish(fleet, &node_name, &neuron_endpoint, model_id, cold_start).await;
|
||||
}
|
||||
|
||||
// Priority 4: catalogue × topology cold-load.
|
||||
if let Some(profile) = fleet.catalogue.get(model_id) {
|
||||
let (node_name, neuron_endpoint) = pick_feasible_neuron(fleet, profile).await?;
|
||||
cold_load(fleet, &node_name, &neuron_endpoint, profile).await?;
|
||||
return finish(fleet, &node_name, &neuron_endpoint, model_id, true).await;
|
||||
}
|
||||
|
||||
Err(RouteError::ModelNotFound(model_id.to_string()))
|
||||
}
|
||||
|
||||
/// Pick a healthy neuron whose discovered topology satisfies the
|
||||
/// profile. Preference order:
|
||||
/// 1. A neuron from `profile.pinned_on` that is healthy + feasible.
|
||||
/// 2. Otherwise, any healthy + feasible neuron, stable by name.
|
||||
async fn pick_feasible_neuron(
|
||||
fleet: &Arc<CortexState>,
|
||||
profile: &ModelProfile,
|
||||
) -> Result<(String, String), RouteError> {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut candidates: Vec<(String, String, bool)> = Vec::new();
|
||||
for node in nodes.values() {
|
||||
if !node.healthy {
|
||||
continue;
|
||||
}
|
||||
let Some(disc) = node.discovery.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
if !profile.is_feasible_on(&node.name, &disc.devices) {
|
||||
continue;
|
||||
}
|
||||
let pinned = profile.pinned_on.iter().any(|n| n == &node.name);
|
||||
candidates.push((node.name.clone(), node.endpoint.clone(), pinned));
|
||||
}
|
||||
candidates.sort_by(|a, b| {
|
||||
b.2.cmp(&a.2) // pinned first (true > false)
|
||||
.then(a.0.cmp(&b.0))
|
||||
});
|
||||
let pick = candidates.into_iter().next();
|
||||
pick.map(|(n, e, _)| (n, e))
|
||||
.ok_or_else(|| RouteError::NoFeasibleNeuron {
|
||||
model_id: profile.id.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Issue `POST {endpoint}/models/load` for this profile on this neuron,
|
||||
/// blocking until the load completes (neuron's load endpoint is
|
||||
/// synchronous — it returns 200 once VRAM is materialised). On success
|
||||
/// also inserts a `Loaded` entry into the local NodeState cache so the
|
||||
/// caller's subsequent endpoint lookup sees the new model without
|
||||
/// waiting for the next poll cycle.
|
||||
async fn cold_load(
|
||||
fleet: &Arc<CortexState>,
|
||||
node_name: &str,
|
||||
neuron_endpoint: &str,
|
||||
profile: &ModelProfile,
|
||||
) -> Result<(), RouteError> {
|
||||
let spec = profile_to_spec(fleet, node_name, profile).await;
|
||||
let url = format!("{neuron_endpoint}/models/load");
|
||||
tracing::info!(model = %profile.id, node = node_name, "cold-loading via /models/load");
|
||||
|
||||
// Generous timeout: a fresh download + safetensors mmap + device
|
||||
// copy for a 30B-class dense model can comfortably exceed 5 min on
|
||||
// a slow link. The HTTP client's own default already covers most
|
||||
// of this; pin a longer per-request bound just here.
|
||||
let resp = match fleet
|
||||
.http_client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(1800))
|
||||
.json(&spec)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Err(RouteError::ColdLoadFailed {
|
||||
model_id: profile.id.clone(),
|
||||
node: node_name.to_string(),
|
||||
message: format!("HTTP request failed: {e}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
// Neuron returns 400 "already loaded" when two concurrent
|
||||
// requests race the same model. Treat that as success — both
|
||||
// requests effectively achieved the same end state.
|
||||
if body.contains("already loaded") {
|
||||
tracing::info!(
|
||||
model = %profile.id,
|
||||
node = node_name,
|
||||
"cold-load saw 'already loaded' — treating as success"
|
||||
);
|
||||
} else {
|
||||
return Err(RouteError::ColdLoadFailed {
|
||||
model_id: profile.id.clone(),
|
||||
node: node_name.to_string(),
|
||||
message: format!("HTTP {status}: {body}"),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
tracing::info!(model = %profile.id, node = node_name, "cold-load returned 200");
|
||||
}
|
||||
|
||||
// Warm the cache: insert a Loaded ModelEntry so the next
|
||||
// resolve() finds the model without waiting for the poll loop.
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
if let Some(node) = nodes.get_mut(node_name) {
|
||||
node.models.insert(
|
||||
profile.id.clone(),
|
||||
cortex_core::node::ModelEntry {
|
||||
id: profile.id.clone(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: Some(chrono::Utc::now()),
|
||||
vram_estimate_mb: profile.vram_mb,
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Translate a `ModelProfile` to a `ModelSpec` neuron's /models/load
|
||||
/// accepts. Devices are picked from the neuron's discovered topology —
|
||||
/// the first `min_devices` indices that meet `min_device_vram_mb`.
|
||||
async fn profile_to_spec(
|
||||
fleet: &Arc<CortexState>,
|
||||
node_name: &str,
|
||||
profile: &ModelProfile,
|
||||
) -> ModelSpec {
|
||||
let devices = {
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let mut picked: Vec<u32> = Vec::new();
|
||||
if let Some(node) = nodes.get(node_name)
|
||||
&& let Some(disc) = &node.discovery
|
||||
{
|
||||
let min_vram = profile.min_device_vram_mb.unwrap_or(0);
|
||||
for d in &disc.devices {
|
||||
if d.vram_total_mb >= min_vram {
|
||||
picked.push(d.index);
|
||||
if picked.len() as u32 >= profile.min_devices {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if picked.is_empty() {
|
||||
// Fall back to a 0..min_devices default; pick_feasible_neuron
|
||||
// already verified the topology satisfies the constraints,
|
||||
// so this only fires if discovery raced or was lost.
|
||||
(0..profile.min_devices).collect()
|
||||
} else {
|
||||
picked
|
||||
}
|
||||
|
||||
loaded_candidate.or(unloaded_candidate).ok_or_else(|| {
|
||||
if nodes.values().any(|n| n.healthy) {
|
||||
RouteError::ModelNotFound(model_id.to_string())
|
||||
} else {
|
||||
RouteError::NoHealthyNodes
|
||||
}
|
||||
})?
|
||||
};
|
||||
|
||||
let tensor_parallel = if profile.min_devices > 1 {
|
||||
Some(profile.min_devices)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
ModelSpec {
|
||||
model_id: qualified_model_id(profile),
|
||||
harness: profile.harness.clone(),
|
||||
quant: profile.quant.clone(),
|
||||
tensor_parallel,
|
||||
devices: Some(devices),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prefix the catalogue id with the scheme when one is declared, so
|
||||
/// neuron resolves the load against the right registry. Without this,
|
||||
/// a profile pointing at the helexa registry would resolve via
|
||||
/// neuron's `default_source` (typically `huggingface`) and fetch
|
||||
/// bytes from the wrong place. Profiles that omit `source` continue
|
||||
/// to pass the bare id through, preserving the pre-Phase-3 contract.
|
||||
///
|
||||
/// Stays at module scope (not nested in `profile_to_spec`) so the unit
|
||||
/// tests can exercise it without spinning up CortexState topology.
|
||||
fn qualified_model_id(profile: &ModelProfile) -> String {
|
||||
match profile.source.as_deref() {
|
||||
Some(scheme) if !scheme.is_empty() => format!("{scheme}:{}", profile.id),
|
||||
_ => profile.id.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
|
||||
/// build the final `RouteDecision`. Shared by all three priority
|
||||
/// branches above.
|
||||
async fn finish(
|
||||
fleet: &Arc<CortexState>,
|
||||
node_name: &str,
|
||||
neuron_endpoint: &str,
|
||||
model_id: &str,
|
||||
cold_start: bool,
|
||||
) -> Result<RouteDecision, RouteError> {
|
||||
// Ask the neuron for the inference endpoint for this model.
|
||||
let endpoint_url = format!(
|
||||
"{}/models/{}/endpoint",
|
||||
neuron_endpoint,
|
||||
@@ -380,119 +89,13 @@ async fn finish(
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let raw = inference_endpoint.ok_or_else(|| {
|
||||
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.to_string())
|
||||
let endpoint = inference_endpoint.ok_or_else(|| {
|
||||
RouteError::EndpointResolveFailed(model_id.to_string(), node_name.clone())
|
||||
})?;
|
||||
|
||||
// Rewrite loopback inference URLs to use the configured neuron host.
|
||||
// Neuron's default bind_url is `http://localhost:13131` (it can't
|
||||
// reliably know its own externally-resolvable name). Cortex sees a
|
||||
// URL that's only meaningful from the neuron host's own perspective;
|
||||
// proxying directly to localhost from a different cortex host would
|
||||
// hit nothing. Keep neuron's port and path (a future harness could
|
||||
// serve inference on a different port than the management API), but
|
||||
// swap the host for the one in cortex.toml.
|
||||
let endpoint = rewrite_loopback_host(&raw, neuron_endpoint).unwrap_or(raw);
|
||||
|
||||
Ok(RouteDecision {
|
||||
node_name: node_name.to_string(),
|
||||
node_name,
|
||||
endpoint,
|
||||
cold_start,
|
||||
resolved_model_id: model_id.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// If `inference_url`'s host is a loopback name (localhost / 127.0.0.1 /
|
||||
/// 0.0.0.0 / ::1), return a copy with the host replaced by
|
||||
/// `neuron_endpoint`'s host. Otherwise return None and the caller falls
|
||||
/// back to the inference URL as-is.
|
||||
fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<String> {
|
||||
let inf = url::Url::parse(inference_url).ok()?;
|
||||
let inf_host = inf.host_str()?;
|
||||
let is_loopback = matches!(inf_host, "localhost" | "127.0.0.1" | "0.0.0.0" | "::1");
|
||||
if !is_loopback {
|
||||
return None;
|
||||
}
|
||||
let neuron = url::Url::parse(neuron_endpoint).ok()?;
|
||||
let new_host = neuron.host_str()?;
|
||||
let mut out = inf.clone();
|
||||
out.set_host(Some(new_host)).ok()?;
|
||||
// url::Url::to_string normalises an empty path to "/", which then
|
||||
// breaks downstream callers that do format!("{endpoint}/v1/...")
|
||||
// and produce a double slash. The proxy URL is treated as a base
|
||||
// string that the caller appends paths to, so strip the trailing
|
||||
// slash here.
|
||||
let s = out.to_string();
|
||||
Some(s.trim_end_matches('/').to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ModelProfile, qualified_model_id, rewrite_loopback_host};
|
||||
|
||||
fn bare_profile(id: &str, source: Option<&str>) -> ModelProfile {
|
||||
ModelProfile {
|
||||
id: id.into(),
|
||||
harness: "candle".into(),
|
||||
quant: None,
|
||||
vram_mb: None,
|
||||
min_devices: 1,
|
||||
min_device_vram_mb: None,
|
||||
pinned_on: vec![],
|
||||
source: source.map(String::from),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn qualified_id_passes_through_when_source_absent() {
|
||||
let p = bare_profile("Qwen/Qwen3-30B", None);
|
||||
assert_eq!(qualified_model_id(&p), "Qwen/Qwen3-30B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn qualified_id_prefixes_when_source_set() {
|
||||
let p = bare_profile("Helexa/Qwen3.6-27B-Uncensored", Some("helexa"));
|
||||
assert_eq!(
|
||||
qualified_model_id(&p),
|
||||
"helexa:Helexa/Qwen3.6-27B-Uncensored"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn qualified_id_passes_through_when_source_is_empty_string() {
|
||||
// An empty scheme is treated as absent — neuron's default_source
|
||||
// substitution kicks in.
|
||||
let p = bare_profile("Qwen/Qwen3-30B", Some(""));
|
||||
assert_eq!(qualified_model_id(&p), "Qwen/Qwen3-30B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rewrites_localhost_keeps_port_and_path() {
|
||||
let out = rewrite_loopback_host(
|
||||
"http://localhost:13131",
|
||||
"http://beast.hanzalova.internal:13131",
|
||||
);
|
||||
assert_eq!(
|
||||
out.as_deref(),
|
||||
Some("http://beast.hanzalova.internal:13131")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rewrites_loopback_with_distinct_inference_port() {
|
||||
let out = rewrite_loopback_host("http://127.0.0.1:8080", "http://beast.lan:13131");
|
||||
assert_eq!(out.as_deref(), Some("http://beast.lan:8080"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leaves_non_loopback_alone() {
|
||||
let out = rewrite_loopback_host("http://other.host:1234", "http://beast.lan:13131");
|
||||
assert_eq!(out, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malformed_inference_url_returns_none() {
|
||||
let out = rewrite_loopback_host("not a url", "http://beast.lan:13131");
|
||||
assert_eq!(out, None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,6 @@ impl CortexState {
|
||||
models: HashMap::new(),
|
||||
lifecycle_cycles: 0,
|
||||
last_poll: None,
|
||||
discovery: None,
|
||||
activation: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
//! Alias resolution: a client request with `model: "helexa/small"`
|
||||
//! routes to the concrete model id (e.g. `Qwen/Qwen3-1.7B`), with the
|
||||
//! proxied request body rewritten so the upstream neuron sees a model
|
||||
//! name that matches its loaded handle.
|
||||
|
||||
mod common;
|
||||
|
||||
use cortex_core::config::{
|
||||
EvictionSettings, EvictionStrategy, GatewayConfig, GatewaySettings, NeuronEndpoint,
|
||||
};
|
||||
use cortex_core::node::{ModelEntry, ModelStatus};
|
||||
use cortex_gateway::state::CortexState;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Write a `models.toml` containing a single `[aliases]` entry that maps
/// `alias` to `target`, and return the path of the new file.
///
/// The file is created under [`std::env::temp_dir`] (which honours
/// `TMPDIR` on Unix), so no `tempfile` dependency is needed. Nothing
/// removes it afterwards: it outlives the test process and is left to
/// the platform's temp-dir cleanup.
///
/// The file name combines the process id, the wall-clock time in
/// nanoseconds and a per-process sequence number. The sequence number
/// keeps names distinct when tests running on parallel threads of the
/// same process observe the same clock reading.
///
/// `alias` and `target` are interpolated verbatim into TOML basic
/// strings, so they must not contain `"` or `\`.
///
/// # Panics
///
/// Panics if the system clock reads earlier than the Unix epoch or if
/// the file cannot be written.
fn write_models_toml(alias: &str, target: &str) -> PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Monotonic per-process counter; only uniqueness matters, so relaxed
    // ordering is sufficient.
    static NEXT_FILE: AtomicU64 = AtomicU64::new(0);

    let contents = format!(
        r#"
[aliases]
"{alias}" = "{target}"
"#
    );

    let pid = std::process::id();
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is before the Unix epoch")
        .as_nanos();
    let seq = NEXT_FILE.fetch_add(1, Ordering::Relaxed);

    let path = std::env::temp_dir().join(format!("cortex-test-models-{pid}-{now}-{seq}.toml"));
    std::fs::write(&path, contents).expect("write temp models.toml");
    path
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_alias_resolves_in_chat_completions() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let models_path = write_models_toml("helexa/small", "test-model");
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "mock-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: models_path.to_string_lossy().to_string(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
|
||||
// Seed the node as healthy with the concrete model loaded under
|
||||
// the target id. The poller doesn't run in this test; we just
|
||||
// populate state manually.
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.healthy = true;
|
||||
node.models.insert(
|
||||
"test-model".into(),
|
||||
ModelEntry {
|
||||
id: "test-model".into(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Sanity: the catalogue actually picked up the alias.
|
||||
assert_eq!(
|
||||
fleet.catalogue.resolve_alias("helexa/small"),
|
||||
"test-model",
|
||||
"alias should resolve to target id"
|
||||
);
|
||||
|
||||
// Spawn the gateway against this fleet.
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let gateway_addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let gateway_url = format!("http://{gateway_addr}");
|
||||
|
||||
// Send a chat completion against the alias. The mock backend
|
||||
// echoes back the `model` field it received — so a body whose
|
||||
// model wasn't rewritten would come back as "helexa/small", and a
|
||||
// properly-rewritten one as "test-model".
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gateway_url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model": "helexa/small",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("gateway should respond");
|
||||
|
||||
assert!(resp.status().is_success(), "gateway returned non-2xx");
|
||||
let body: serde_json::Value = resp.json().await.expect("response is JSON");
|
||||
assert_eq!(
|
||||
body.get("model").and_then(|m| m.as_str()),
|
||||
Some("test-model"),
|
||||
"mock backend should have seen the resolved model id, not the alias"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aliases_surface_in_v1_models() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let models_path = write_models_toml("helexa/small", "test-model");
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "mock-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: models_path.to_string_lossy().to_string(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
|
||||
// Seed the target as loaded so the alias's mirrored entry shows
|
||||
// loaded=true.
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.healthy = true;
|
||||
node.models.insert(
|
||||
"test-model".into(),
|
||||
ModelEntry {
|
||||
id: "test-model".into(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: Some(2000),
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let gateway_addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let gateway_url = format!("http://{gateway_addr}");
|
||||
|
||||
let resp = reqwest::get(format!("{gateway_url}/v1/models"))
|
||||
.await
|
||||
.expect("gateway should respond");
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let entries = body
|
||||
.get("data")
|
||||
.and_then(|d| d.as_array())
|
||||
.expect("data array");
|
||||
|
||||
// Both the alias and the target should be present.
|
||||
let ids: Vec<&str> = entries
|
||||
.iter()
|
||||
.filter_map(|e| e.get("id").and_then(|v| v.as_str()))
|
||||
.collect();
|
||||
assert!(ids.contains(&"test-model"), "target should be listed");
|
||||
assert!(ids.contains(&"helexa/small"), "alias should be listed");
|
||||
|
||||
// The alias's `loaded` flag and locations should mirror the target.
|
||||
let alias_entry = entries
|
||||
.iter()
|
||||
.find(|e| e.get("id").and_then(|v| v.as_str()) == Some("helexa/small"))
|
||||
.expect("alias entry");
|
||||
assert_eq!(alias_entry.get("loaded"), Some(&json!(true)));
|
||||
let locations = alias_entry
|
||||
.get("locations")
|
||||
.and_then(|l| l.as_array())
|
||||
.expect("locations array");
|
||||
assert_eq!(locations.len(), 1);
|
||||
assert_eq!(
|
||||
locations[0].get("node").and_then(|n| n.as_str()),
|
||||
Some("mock-node")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_alias_falls_through_for_unmapped_model() {
|
||||
// Catalogue has an alias for some-other-thing but the request
|
||||
// model "test-model" isn't an alias; resolution should be a no-op.
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let models_path = write_models_toml("helexa/large", "definitely-not-loaded");
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "mock-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: models_path.to_string_lossy().to_string(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.healthy = true;
|
||||
node.models.insert(
|
||||
"test-model".into(),
|
||||
ModelEntry {
|
||||
id: "test-model".into(),
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let gateway_addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
let gateway_url = format!("http://{gateway_addr}");
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{gateway_url}/v1/chat/completions"))
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(resp.status().is_success());
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(
|
||||
body.get("model").and_then(|m| m.as_str()),
|
||||
Some("test-model")
|
||||
);
|
||||
}
|
||||
@@ -123,124 +123,3 @@ async fn test_anthropic_invalid_request() {
|
||||
|
||||
assert_eq!(resp.status(), 400);
|
||||
}
|
||||
|
||||
/// #24: a streaming Anthropic request gets a translated Anthropic SSE
|
||||
/// stream — not raw OpenAI frames. Verifies the full event sequence,
|
||||
/// text reassembly, and the content type.
|
||||
#[tokio::test]
|
||||
async fn test_anthropic_streaming_sse_translation() {
|
||||
let mock_url =
|
||||
common::spawn_streaming_mock_neuron(4, std::time::Duration::from_millis(20)).await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/messages"))
|
||||
.header("content-type", "application/json")
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"max_tokens": 64,
|
||||
"stream": true,
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
assert!(
|
||||
resp.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.starts_with("text/event-stream"),
|
||||
"anthropic stream must be SSE"
|
||||
);
|
||||
|
||||
let body = resp.text().await.expect("stream should complete");
|
||||
assert!(
|
||||
!body.contains("chat.completion.chunk"),
|
||||
"raw OpenAI frames must not leak through:\n{body}"
|
||||
);
|
||||
|
||||
let event_names: Vec<&str> = body
|
||||
.lines()
|
||||
.filter_map(|l| l.strip_prefix("event: "))
|
||||
.collect();
|
||||
assert_eq!(
|
||||
event_names,
|
||||
vec![
|
||||
"message_start",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
"content_block_delta",
|
||||
"content_block_delta",
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
],
|
||||
"unexpected event sequence:\n{body}"
|
||||
);
|
||||
|
||||
// Reassemble the text deltas: the mock emits token0..token3.
|
||||
let text: String = body
|
||||
.lines()
|
||||
.filter_map(|l| l.strip_prefix("data: "))
|
||||
.filter_map(|d| serde_json::from_str::<serde_json::Value>(d).ok())
|
||||
.filter(|v| v["type"] == "content_block_delta")
|
||||
.filter_map(|v| v["delta"]["text"].as_str().map(String::from))
|
||||
.collect();
|
||||
assert_eq!(text, "token0token1token2token3");
|
||||
|
||||
// The mock sends no finish_reason — stop_reason defaults to
|
||||
// end_turn, and output_tokens falls back to the delta count.
|
||||
let message_delta = body
|
||||
.lines()
|
||||
.filter_map(|l| l.strip_prefix("data: "))
|
||||
.filter_map(|d| serde_json::from_str::<serde_json::Value>(d).ok())
|
||||
.find(|v| v["type"] == "message_delta")
|
||||
.expect("message_delta event present");
|
||||
assert_eq!(message_delta["delta"]["stop_reason"], "end_turn");
|
||||
assert_eq!(message_delta["usage"]["output_tokens"], 4);
|
||||
}
|
||||
|
||||
/// #24: an upstream usage frame (stream_options include_usage shape)
|
||||
/// rides into message_delta as input/output token counts.
|
||||
#[tokio::test]
|
||||
async fn test_anthropic_streaming_usage_propagation() {
|
||||
let mock_url = common::spawn_streaming_mock_neuron_with_usage(
|
||||
3,
|
||||
std::time::Duration::from_millis(10),
|
||||
225,
|
||||
42,
|
||||
)
|
||||
.await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let body = client
|
||||
.post(format!("{gw_url}/v1/messages"))
|
||||
.header("content-type", "application/json")
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"max_tokens": 64,
|
||||
"stream": true,
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed")
|
||||
.text()
|
||||
.await
|
||||
.expect("stream should complete");
|
||||
|
||||
let message_delta = body
|
||||
.lines()
|
||||
.filter_map(|l| l.strip_prefix("data: "))
|
||||
.filter_map(|d| serde_json::from_str::<serde_json::Value>(d).ok())
|
||||
.find(|v| v["type"] == "message_delta")
|
||||
.expect("message_delta event present");
|
||||
assert_eq!(message_delta["usage"]["output_tokens"], 42);
|
||||
assert_eq!(message_delta["usage"]["input_tokens"], 225);
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ use tokio::net::TcpListener;
|
||||
/// - GET /models/:id/endpoint (returns the inference URL)
|
||||
/// - POST /models/unload (accepts unload requests)
|
||||
/// - GET /v1/chat/completions + POST /v1/chat/completions (inference)
|
||||
///
|
||||
/// Returns the neuron base URL.
|
||||
pub async fn spawn_mock_neuron() -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
@@ -44,7 +43,6 @@ pub async fn spawn_mock_neuron() -> String {
|
||||
post(|Json(_body): Json<Value>| async { Json(json!({"status": "unloaded"})) }),
|
||||
)
|
||||
.route("/v1/chat/completions", post(mock_chat_completions))
|
||||
.route("/v1/responses", post(mock_responses))
|
||||
.route("/v1/models", get(mock_v1_models));
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -56,7 +54,7 @@ pub async fn spawn_mock_neuron() -> String {
|
||||
|
||||
async fn mock_neuron_list_models() -> Json<Value> {
|
||||
Json(json!([
|
||||
{"id": "test-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||
{"id": "test-model", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000}
|
||||
]))
|
||||
}
|
||||
|
||||
@@ -94,39 +92,6 @@ async fn mock_chat_completions(Json(body): Json<Value>) -> Json<Value> {
|
||||
}))
|
||||
}
|
||||
|
||||
async fn mock_responses(Json(body): Json<Value>) -> Json<Value> {
|
||||
let model = body
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
// Echo the model field back and synthesise a tiny ResponsesResponse.
|
||||
// Mirrors the shape neuron's /v1/responses handler emits so the
|
||||
// gateway test only needs to assert the proxy round-tripped it.
|
||||
Json(json!({
|
||||
"id": "resp-test-001",
|
||||
"object": "response",
|
||||
"created_at": 1700000000_u64,
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"id": "msg-test-001",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "Hello from mock backend",
|
||||
"annotations": []
|
||||
}],
|
||||
"status": "completed"
|
||||
}],
|
||||
"usage": {
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 10
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Spawns a mock neuron that returns SSE streaming responses for chat completions.
|
||||
pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Duration) -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
@@ -196,120 +161,8 @@ pub async fn spawn_streaming_mock_neuron(chunk_count: usize, chunk_delay: Durati
|
||||
base_url
|
||||
}
|
||||
|
||||
/// Like `spawn_streaming_mock_neuron`, but the stream ends with an
|
||||
/// OpenAI `stream_options.include_usage`-style final chunk (empty
|
||||
/// choices + usage object) before `[DONE]` — the shape the gateway's
|
||||
/// token metrics (#21) extract counts from.
|
||||
pub async fn spawn_streaming_mock_neuron_with_usage(
|
||||
chunk_count: usize,
|
||||
chunk_delay: Duration,
|
||||
prompt_tokens: u64,
|
||||
completion_tokens: u64,
|
||||
) -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let base_url = format!("http://{addr}");
|
||||
let inference_url = base_url.clone();
|
||||
|
||||
let app = Router::new()
|
||||
.route("/models", get(mock_neuron_list_models))
|
||||
.route(
|
||||
"/models/{model_id}/endpoint",
|
||||
get(move |Path(_model_id): Path<String>| {
|
||||
let url = inference_url.clone();
|
||||
async move { Json(json!({"url": url})) }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/v1/chat/completions",
|
||||
post(move |Json(body): Json<Value>| async move {
|
||||
let model = body
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let mut chunks: Vec<String> = (0..chunk_count)
|
||||
.map(|i| {
|
||||
let chunk = json!({
|
||||
"id": "chatcmpl-stream-002",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1700000000_u64,
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": { "content": format!("token{i}") },
|
||||
"finish_reason": null
|
||||
}]
|
||||
});
|
||||
format!("data: {chunk}\n\n")
|
||||
})
|
||||
.collect();
|
||||
let usage_chunk = json!({
|
||||
"id": "chatcmpl-stream-002",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1700000000_u64,
|
||||
"model": model,
|
||||
"choices": [],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
}
|
||||
});
|
||||
chunks.push(format!("data: {usage_chunk}\n\n"));
|
||||
chunks.push("data: [DONE]\n\n".to_string());
|
||||
|
||||
let delay = chunk_delay;
|
||||
let stream = stream::iter(chunks).then(move |chunk| async move {
|
||||
tokio::time::sleep(delay).await;
|
||||
Ok::<_, std::convert::Infallible>(chunk)
|
||||
});
|
||||
|
||||
Response::builder()
|
||||
.header(header::CONTENT_TYPE, "text/event-stream")
|
||||
.header(header::CACHE_CONTROL, "no-cache")
|
||||
.body(Body::from_stream(stream))
|
||||
.unwrap()
|
||||
}),
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
base_url
|
||||
}
|
||||
|
||||
/// Spawns a mock neuron with a custom models list.
|
||||
pub async fn spawn_mock_neuron_with_models(models_response: Value) -> String {
|
||||
spawn_mock_neuron_with_models_and_health(models_response, default_health_response()).await
|
||||
}
|
||||
|
||||
/// Default `/health` response used by mocks that don't care about the
|
||||
/// activation field — empty devices, no in-flight pre-warm, state=ready.
|
||||
pub fn default_health_response() -> Value {
|
||||
json!({
|
||||
"uptime_secs": 0,
|
||||
"devices": [],
|
||||
"activation": {
|
||||
"state": "ready",
|
||||
"pending": [],
|
||||
"in_progress": null,
|
||||
"completed": [],
|
||||
"failed": []
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Variant of `spawn_mock_neuron_with_models` that also serves a
|
||||
/// `/health` body. Used by tests that drive the gateway's activation
|
||||
/// surface (poller reading /health, /v1/models synthesising Loading
|
||||
/// locations from in_progress / pending).
|
||||
pub async fn spawn_mock_neuron_with_models_and_health(
|
||||
models_response: Value,
|
||||
health_response: Value,
|
||||
) -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let base_url = format!("http://{addr}");
|
||||
@@ -323,13 +176,6 @@ pub async fn spawn_mock_neuron_with_models_and_health(
|
||||
async move { Json(resp) }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/health",
|
||||
get(move || {
|
||||
let resp = health_response.clone();
|
||||
async move { Json(resp) }
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/models/{model_id}/endpoint",
|
||||
get(move |Path(_model_id): Path<String>| {
|
||||
@@ -390,7 +236,6 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, Stri
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: Some(8000),
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -91,7 +91,6 @@ async fn test_evict_lru_model() {
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: Some(Utc::now() - chrono::Duration::hours(2)),
|
||||
vram_estimate_mb: Some(8000),
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
node.models.insert(
|
||||
@@ -101,7 +100,6 @@ async fn test_evict_lru_model() {
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: Some(Utc::now()),
|
||||
vram_estimate_mb: Some(8000),
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -165,7 +163,6 @@ async fn test_eviction_increments_lifecycle_cycles() {
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,26 +1,20 @@
|
||||
mod common;
|
||||
|
||||
use serde_json::json;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// The metrics recorder is a process-wide global; both tests in this
|
||||
/// binary run against one shared install. Assertions must therefore be
|
||||
/// order-independent (presence of names / monotonic counters, not
|
||||
/// "empty before").
|
||||
fn recorder() -> &'static metrics_exporter_prometheus::PrometheusHandle {
|
||||
static HANDLE: OnceLock<metrics_exporter_prometheus::PrometheusHandle> = OnceLock::new();
|
||||
HANDLE.get_or_init(|| {
|
||||
cortex_gateway::metrics::install_test_recorder().expect("recorder should install")
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metrics_emitted_after_proxy() {
|
||||
let handle = recorder();
|
||||
let handle = cortex_gateway::metrics::install_test_recorder().expect("recorder should install");
|
||||
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let before = handle.render();
|
||||
assert!(
|
||||
!before.contains("cortex_requests_total"),
|
||||
"no request metrics before any requests"
|
||||
);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/chat/completions"))
|
||||
@@ -50,72 +44,3 @@ async fn test_metrics_emitted_after_proxy() {
|
||||
"no errors expected for a successful request"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_token_metrics_emitted_for_streamed_request() {
|
||||
// #21: a streamed chat completion with a final usage chunk must
|
||||
// produce TTFT + tok/s histograms and prompt/completion token
|
||||
// counters, labelled with model and node. The recorder is global
|
||||
// per-process, so this test runs in its own binary invocation —
|
||||
// cargo's per-file integration binaries give us that as long as
|
||||
// only one test in this file installs the recorder... it isn't:
|
||||
// test_metrics_emitted_after_proxy also installs. Whichever wins
|
||||
// the race, both render from the same recorder, so assert on
|
||||
// delta-able names rather than exact totals.
|
||||
let handle = recorder();
|
||||
|
||||
let mock_url = common::spawn_streaming_mock_neuron_with_usage(
|
||||
5,
|
||||
std::time::Duration::from_millis(40),
|
||||
225,
|
||||
42,
|
||||
)
|
||||
.await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/chat/completions"))
|
||||
.header("content-type", "application/json")
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body = resp.text().await.expect("stream should complete");
|
||||
assert!(body.contains("[DONE]"));
|
||||
|
||||
let rendered = handle.render();
|
||||
for needle in [
|
||||
"cortex_time_to_first_token_seconds",
|
||||
"cortex_tokens_per_second",
|
||||
] {
|
||||
assert!(
|
||||
rendered.contains(needle),
|
||||
"{needle} should be present.\nMetrics:\n{rendered}"
|
||||
);
|
||||
}
|
||||
// The recorder is shared with the sibling test (same model/node
|
||||
// labels), so counters are lower bounds, not exact values: this
|
||||
// request contributed prompt=225 / completion=42.
|
||||
let counter_value = |name: &str| -> u64 {
|
||||
rendered
|
||||
.lines()
|
||||
.find(|l| l.starts_with(name) && l.contains(r#"model="test-model""#))
|
||||
.and_then(|l| l.rsplit(' ').next())
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or_else(|| panic!("{name} should be present.\nMetrics:\n{rendered}"))
|
||||
};
|
||||
assert!(
|
||||
counter_value("cortex_prompt_tokens_total") >= 225,
|
||||
"prompt token counter should include this request's 225.\nMetrics:\n{rendered}"
|
||||
);
|
||||
assert!(
|
||||
counter_value("cortex_completion_tokens_total") >= 42,
|
||||
"completion token counter should include this request's 42.\nMetrics:\n{rendered}"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ use std::sync::Arc;
|
||||
async fn test_poller_discovers_models() {
|
||||
// Mock neuron reports 2 models via /models endpoint (neuron format).
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "model-a", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||
{"id": "model-b", "harness": "candle", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||
{"id": "model-a", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": 8000},
|
||||
{"id": "model-b", "harness": "mistralrs", "status": "unloaded", "devices": [], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -63,8 +63,8 @@ async fn test_poller_discovers_models() {
|
||||
#[tokio::test]
|
||||
async fn test_poller_updates_gateway_models_endpoint() {
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "model-x", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "model-y", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||
{"id": "model-x", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "model-y", "harness": "mistralrs", "status": "loaded", "devices": [1], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -118,87 +118,6 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_models_endpoint_unions_capabilities_across_nodes() {
|
||||
// C3: two neurons each have the same model loaded but advertise
|
||||
// different capability sets. The gateway's /v1/models must report
|
||||
// the union — a model loaded text-only on one node and
|
||||
// text+vision on another is vision-capable to the fleet.
|
||||
let node_a = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "shared-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null, "capabilities": ["text"]}
|
||||
]))
|
||||
.await;
|
||||
let node_b = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "shared-model", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null, "capabilities": ["text", "vision"]}
|
||||
]))
|
||||
.await;
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![
|
||||
NeuronEndpoint {
|
||||
name: "node-a".into(),
|
||||
endpoint: node_a,
|
||||
},
|
||||
NeuronEndpoint {
|
||||
name: "node-b".into(),
|
||||
endpoint: node_b,
|
||||
},
|
||||
],
|
||||
models_config: "/dev/null".into(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
cortex_gateway::poller::poll_once(&fleet).await;
|
||||
|
||||
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let body: serde_json::Value = client
|
||||
.get(format!("http://{addr}/v1/models"))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed")
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let model = body["data"]
|
||||
.as_array()
|
||||
.expect("data array")
|
||||
.iter()
|
||||
.find(|m| m["id"] == "shared-model")
|
||||
.expect("shared-model should be present");
|
||||
|
||||
let caps: Vec<&str> = model["capabilities"]
|
||||
.as_array()
|
||||
.expect("capabilities array")
|
||||
.iter()
|
||||
.filter_map(|c| c.as_str())
|
||||
.collect();
|
||||
assert!(caps.contains(&"text"), "union must include text: {caps:?}");
|
||||
assert!(
|
||||
caps.contains(&"vision"),
|
||||
"union must include vision: {caps:?}"
|
||||
);
|
||||
assert_eq!(caps.len(), 2, "union must not duplicate text: {caps:?}");
|
||||
|
||||
// Both nodes hold the model, so two locations regardless of caps.
|
||||
assert_eq!(model["locations"].as_array().unwrap().len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_poller_marks_unreachable_node_unhealthy() {
|
||||
let config = GatewayConfig {
|
||||
@@ -233,8 +152,8 @@ async fn test_poller_marks_unreachable_node_unhealthy() {
|
||||
#[tokio::test]
|
||||
async fn test_poller_removes_stale_models() {
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "drop-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null},
|
||||
{"id": "drop-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -264,7 +183,7 @@ async fn test_poller_removes_stale_models() {
|
||||
|
||||
// New mock with only one model.
|
||||
let new_mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "keep-me", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
{"id": "keep-me", "harness": "mistralrs", "status": "loaded", "devices": [0], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
@@ -297,7 +216,6 @@ async fn test_poller_removes_stale_models() {
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
node.models.insert(
|
||||
@@ -307,7 +225,6 @@ async fn test_poller_removes_stale_models() {
|
||||
status: ModelStatus::Loaded,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: None,
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -320,94 +237,3 @@ async fn test_poller_removes_stale_models() {
|
||||
assert!(node.models.contains_key("keep-me"));
|
||||
assert!(!node.models.contains_key("drop-me"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_poller_captures_activation_from_health() {
|
||||
// Mock neuron is mid-prewarm: /models reports nothing (the loading
|
||||
// model hasn't been inserted into the harness map yet), but
|
||||
// /health's activation says model-x is in_progress and model-y is
|
||||
// queued behind it.
|
||||
let mock_url = common::spawn_mock_neuron_with_models_and_health(
|
||||
json!([]),
|
||||
json!({
|
||||
"uptime_secs": 30,
|
||||
"devices": [],
|
||||
"activation": {
|
||||
"state": "pre_warming",
|
||||
"pending": ["Qwen/model-y"],
|
||||
"in_progress": "Qwen/model-x",
|
||||
"completed": [],
|
||||
"failed": []
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "prewarm-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: "/dev/null".into(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
cortex_gateway::poller::poll_once(&fleet).await;
|
||||
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let node = nodes.get("prewarm-node").unwrap();
|
||||
assert!(node.healthy);
|
||||
// /models was empty — no entries in the per-node model map.
|
||||
assert!(node.models.is_empty());
|
||||
// But /health's activation should be captured.
|
||||
let activation = node
|
||||
.activation
|
||||
.as_ref()
|
||||
.expect("activation should be populated after /health poll");
|
||||
assert_eq!(activation.in_progress.as_deref(), Some("Qwen/model-x"));
|
||||
assert_eq!(activation.pending, vec!["Qwen/model-y".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_poller_parses_recovering_status() {
|
||||
// #20: a model auto-recovering on a neuron (poisoned → unload →
|
||||
// reload, #17) is reported with status "recovering" and must land
|
||||
// in gateway state as the dedicated Recovering status — not fall
|
||||
// through the parser's catch-all to Loaded.
|
||||
let mock_url = common::spawn_mock_neuron_with_models(json!([
|
||||
{"id": "model-r", "harness": "candle", "status": "recovering", "devices": [0, 1], "vram_used_mb": null}
|
||||
]))
|
||||
.await;
|
||||
|
||||
let config = GatewayConfig {
|
||||
gateway: GatewaySettings {
|
||||
listen: "127.0.0.1:0".into(),
|
||||
metrics_listen: "127.0.0.1:0".into(),
|
||||
},
|
||||
eviction: EvictionSettings {
|
||||
strategy: EvictionStrategy::Lru,
|
||||
defrag_after_cycles: 0,
|
||||
},
|
||||
neurons: vec![NeuronEndpoint {
|
||||
name: "test-node".into(),
|
||||
endpoint: mock_url,
|
||||
}],
|
||||
models_config: "/dev/null".into(),
|
||||
};
|
||||
|
||||
let fleet = Arc::new(CortexState::from_config(&config));
|
||||
cortex_gateway::poller::poll_once(&fleet).await;
|
||||
|
||||
let nodes = fleet.nodes.read().await;
|
||||
let node = nodes.get("test-node").unwrap();
|
||||
let model_r = node.models.get("model-r").expect("model-r should exist");
|
||||
assert_eq!(model_r.status, ModelStatus::Recovering);
|
||||
}
|
||||
|
||||
@@ -171,64 +171,3 @@ async fn test_missing_model_field() {
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert!(body["error"]["message"].as_str().unwrap().contains("model"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_recovering_model_returns_503_and_stays_listed() {
|
||||
// #20: while a model auto-recovers on a neuron, the gateway must
|
||||
// hold the route — transient 503 ("retry shortly"), not the 404
|
||||
// "not found on any node" that makes a recovering model look
|
||||
// evicted — and keep listing it on /v1/models.
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let (fleet, gw_url) = common::spawn_gateway_with_state(&mock_url).await;
|
||||
|
||||
{
|
||||
let mut nodes = fleet.nodes.write().await;
|
||||
let node = nodes.get_mut("mock-node").expect("node must exist");
|
||||
node.models.insert(
|
||||
"recovering-model".into(),
|
||||
cortex_core::node::ModelEntry {
|
||||
id: "recovering-model".into(),
|
||||
status: cortex_core::node::ModelStatus::Recovering,
|
||||
last_accessed: None,
|
||||
vram_estimate_mb: Some(8000),
|
||||
capabilities: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/chat/completions"))
|
||||
.header("content-type", "application/json")
|
||||
.json(&json!({
|
||||
"model": "recovering-model",
|
||||
"messages": [{"role": "user", "content": "Hi"}]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(resp.status(), 503);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let message = body["error"]["message"].as_str().unwrap();
|
||||
assert!(
|
||||
message.contains("recovering") && message.contains("retry"),
|
||||
"503 body must say recovering/retry, got: {message}"
|
||||
);
|
||||
|
||||
// The model must still be visible on the unified models endpoint.
|
||||
let models: serde_json::Value = client
|
||||
.get(format!("{gw_url}/v1/models"))
|
||||
.send()
|
||||
.await
|
||||
.expect("models request should succeed")
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
let listed = models["data"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|m| m["id"] == "recovering-model");
|
||||
assert!(listed, "recovering model must stay listed on /v1/models");
|
||||
}
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
//! Integration tests for the `/v1/responses` proxy route.
|
||||
//!
|
||||
//! The gateway forwards the request body to whichever neuron has the
|
||||
//! model loaded. These tests exercise the routing decision (200 on a
|
||||
//! known model, 404 on an unknown model, 400 on a missing model
|
||||
//! field) and confirm the response body round-trips verbatim.
|
||||
|
||||
mod common;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
/// Happy path: gateway routes a `/v1/responses` request to the neuron
|
||||
/// that has the model loaded, and the neuron's response body
|
||||
/// arrives at the client unchanged.
|
||||
#[tokio::test]
|
||||
async fn test_responses_proxy() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/responses"))
|
||||
.header("content-type", "application/json")
|
||||
.json(&json!({
|
||||
"model": "test-model",
|
||||
"input": "Hi"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
|
||||
let body: serde_json::Value = resp.json().await.expect("valid JSON response");
|
||||
assert_eq!(body["id"], "resp-test-001");
|
||||
assert_eq!(body["object"], "response");
|
||||
assert_eq!(body["model"], "test-model");
|
||||
assert_eq!(body["status"], "completed");
|
||||
assert_eq!(
|
||||
body["output"][0]["content"][0]["text"],
|
||||
"Hello from mock backend"
|
||||
);
|
||||
// Usage shape is the Responses-specific (input/output_tokens),
|
||||
// not the chat-completions one (prompt/completion_tokens). Asserts
|
||||
// the proxy didn't accidentally route through the wrong handler.
|
||||
assert_eq!(body["usage"]["total_tokens"], 10);
|
||||
assert!(body["usage"].get("input_tokens").is_some());
|
||||
}
|
||||
|
||||
/// A request that targets a model not present in the catalogue gets
|
||||
/// 404 from the router. This matches the chat-completions handler's
|
||||
/// behaviour — same error path, same status code, so a client can
|
||||
/// share retry logic across the two routes.
|
||||
#[tokio::test]
|
||||
async fn test_responses_model_not_found() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/responses"))
|
||||
.json(&json!({
|
||||
"model": "not-in-catalogue",
|
||||
"input": "Hi"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
/// A request body without a `model` field can't be routed; the
|
||||
/// gateway returns 400 before reaching a backend. Same as the
|
||||
/// chat-completions handler — extracted via the same `extract_model`
|
||||
/// helper.
|
||||
#[tokio::test]
|
||||
async fn test_responses_missing_model_field() {
|
||||
let mock_url = common::spawn_mock_neuron().await;
|
||||
let gw_url = common::spawn_gateway(&mock_url).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("{gw_url}/v1/responses"))
|
||||
.json(&json!({
|
||||
"input": "Hi"
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
}
|
||||
@@ -51,18 +51,18 @@ async fn test_streaming_sse_passthrough() {
|
||||
}
|
||||
|
||||
assert!(
|
||||
chunks.len() > chunk_count,
|
||||
"expected more than {} chunks (got {}): {:?}",
|
||||
chunk_count,
|
||||
chunks.len() >= chunk_count + 1,
|
||||
"expected at least {} chunks (got {}): {:?}",
|
||||
chunk_count + 1,
|
||||
chunks.len(),
|
||||
chunks,
|
||||
);
|
||||
|
||||
assert_eq!(chunks.last().unwrap(), "[DONE]");
|
||||
|
||||
for (i, chunk) in chunks.iter().enumerate().take(chunk_count) {
|
||||
for i in 0..chunk_count {
|
||||
let chunk_json: serde_json::Value =
|
||||
serde_json::from_str(chunk).expect("chunk should be valid JSON");
|
||||
serde_json::from_str(&chunks[i]).expect("chunk should be valid JSON");
|
||||
assert_eq!(
|
||||
chunk_json["choices"][0]["delta"]["content"],
|
||||
format!("token{i}")
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
[package]
|
||||
name = "helexa-acp"
|
||||
version = "0.1.16"
|
||||
edition = "2024"
|
||||
license = "Apache-2.0"
|
||||
repository = "https://git.lair.cafe/helexa/helexa"
|
||||
description = """
|
||||
Agent Client Protocol bridge for the helexa self-hosted LLM stack.
|
||||
Speaks ACP to ACP-compatible editor clients (Zed, etc.) and forwards
|
||||
the conversation to any OpenAI-compatible HTTP endpoint — defaulting
|
||||
to cortex (helexa's reverse-proxy / fleet gateway).
|
||||
"""
|
||||
|
||||
# This crate is intentionally self-contained — no dependencies on other
|
||||
# workspace crates (cortex-core, cortex-gateway, neuron). The goal is
|
||||
# a painless migration to a dedicated GitHub repo in the future if the
|
||||
# project grows beyond helexa's needs. All deps are crates.io.
|
||||
[dependencies]
|
||||
# `unstable_session_model` flips on the SessionModelState type and the
|
||||
# session/set_model RPC the model-picker dropdown in Zed needs. The
|
||||
# feature is upstream-marked unstable; we accept that risk because the
|
||||
# model picker is core UX and the alternative (rolling our own
|
||||
# extension method) drifts further from spec each time it moves.
|
||||
agent-client-protocol = { version = "0.12", features = ["unstable_session_model"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "io-util", "process", "signal"] }
|
||||
reqwest = { version = "0.12", features = ["json", "stream", "rustls-tls"], default-features = false }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
toml = "0.8"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
anyhow = "1"
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
tokio-stream = "0.1"
|
||||
tokio-util = { version = "0.7", features = ["rt"] }
|
||||
eventsource-stream = "0.2"
|
||||
async-stream = "0.3"
|
||||
url = { version = "2", features = ["serde"] }
|
||||
# Already transitively pulled via the ACP SDK; declared directly so we
|
||||
# can format ISO 8601 timestamps for `SessionInfo.updated_at` in the
|
||||
# session/list response.
|
||||
chrono = { version = "0.4", default-features = false, features = ["std"] }
|
||||
|
||||
[[bin]]
|
||||
name = "helexa-acp"
|
||||
path = "src/main.rs"
|
||||
@@ -1,546 +0,0 @@
|
||||
# helexa-acp
|
||||
|
||||
ACP (Agent Client Protocol) bridge for editors like
|
||||
[Zed](https://zed.dev). Lets you point your editor's agent panel at
|
||||
**any combination** of OpenAI-compatible, OpenAI Responses, and
|
||||
Anthropic Messages endpoints — public APIs, private LAN deployments,
|
||||
local Ollama / LM Studio — and switch between them per session via a
|
||||
model dropdown.
|
||||
|
||||
The "missing ACP binary" for users who don't want to be locked into
|
||||
one vendor's agent client.
|
||||
|
||||
```
|
||||
┌───────────────────────────────────┐
|
||||
│ Zed (or any ACP editor client) │
|
||||
└────────────┬──────────────────────┘
|
||||
│ stdio JSON-RPC (ACP)
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ helexa-acp │ ← one binary, multi-endpoint
|
||||
└─────┬───────────┘
|
||||
│ HTTP / SSE
|
||||
┌────────┼─────────────┬──────────────┬──────────────┐
|
||||
▼ ▼ ▼ ▼ ▼
|
||||
cortex/ OpenAI Anthropic OpenRouter LM Studio
|
||||
neuron Responses Messages
|
||||
(self- (gpt-5,…) (Claude)
|
||||
hosted)
|
||||
```
|
||||
|
||||
## What it does
|
||||
|
||||
- **Speaks ACP** over stdio to editor clients (Zed today; any future
|
||||
ACP client tomorrow).
|
||||
- **Multi-endpoint** — one config file lists every LLM endpoint
|
||||
you want available; pick one per session via the model dropdown
|
||||
(`endpoint:model` selector).
|
||||
- **Three wire formats**: `openai-chat` (the broadly compatible
|
||||
default), `openai-responses` (newer OpenAI surface), and
|
||||
`anthropic-messages` (Claude). Each is a separate provider impl
|
||||
in `src/provider/`; adding a fourth (Gemini, Ollama native, …) is
|
||||
one file plus a `WireApi` enum variant.
|
||||
- **Built-in tools**: `read_file`, `write_file`, `edit_file`,
|
||||
`list_dir`, `bash`. Permission-gated by default; the editor user
|
||||
approves writes/shell per-call.
|
||||
- **Three session modes**: Default (gated), Bypass Permissions
|
||||
(auto-allow), and Plan (writes only into the plan dir, no shell).
|
||||
- **Vision** — drag-drop images into the agent panel against any
|
||||
vision-capable model.
|
||||
- **Session resume** — multi-day conversations survive editor
|
||||
restarts via on-disk transcript persistence.
|
||||
- **Context compaction** — rolling history stays inside the model's
|
||||
context window automatically so long sessions on small-context
|
||||
local models don't fall over.
|
||||
|
||||
## Install
|
||||
|
||||
### From source
|
||||
|
||||
```sh
|
||||
git clone https://git.lair.cafe/helexa/helexa.git
|
||||
cd helexa
|
||||
cargo install --path crates/helexa-acp
|
||||
# Binary lands at ~/.cargo/bin/helexa-acp
|
||||
```
|
||||
|
||||
### Pre-built RPM (Fedora 43)
|
||||
|
||||
```sh
|
||||
dnf copr enable helexa/helexa
|
||||
dnf install helexa-acp
|
||||
```
|
||||
|
||||
The COPR project bundles helexa-acp alongside the cortex gateway
|
||||
and helexa-neuron flavours; install only the package(s) you need.
|
||||
|
||||
## Quick start
|
||||
|
||||
The fastest path: env-var single-endpoint config.
|
||||
|
||||
```sh
|
||||
export HELEXA_ACP_BASE_URL=http://hanzalova.internal:31313/v1
|
||||
export HELEXA_ACP_MODEL=Qwen/Qwen3.6-27B
|
||||
helexa-acp # speaks ACP over stdin/stdout; not interactive
|
||||
```
|
||||
|
||||
Then in Zed (`~/.config/zed/settings.json`):
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"agent_servers": {
|
||||
"helexa": {
|
||||
"command": "helexa-acp",
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Restart Zed → open the agent panel → pick "helexa" → start
|
||||
chatting. Tool calls (file reads, writes, bash) prompt for
|
||||
permission per-call in Default mode.
|
||||
|
||||
That's the minimum. The full config story below is what unlocks
|
||||
the multi-endpoint dropdown.
|
||||
|
||||
## Multi-endpoint config
|
||||
|
||||
Copy `helexa-acp.example.toml` from this repo to
|
||||
`$XDG_CONFIG_HOME/helexa-acp/config.toml` (typically
|
||||
`~/.config/helexa-acp/config.toml`) and edit:
|
||||
|
||||
```toml
|
||||
default_endpoint = "helexa"
|
||||
|
||||
[[endpoints]]
|
||||
name = "helexa"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "Qwen/Qwen3.6-27B"
|
||||
max_tokens = 8192
|
||||
context_window = 32768
|
||||
|
||||
[[endpoints]]
|
||||
name = "openrouter"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "OPENROUTER_API_KEY"
|
||||
default_model = "anthropic/claude-opus-4"
|
||||
|
||||
[[endpoints]]
|
||||
name = "anthropic"
|
||||
base_url = "https://api.anthropic.com/v1"
|
||||
wire_api = "anthropic-messages"
|
||||
api_key_env = "ANTHROPIC_API_KEY"
|
||||
default_model = "claude-opus-4"
|
||||
```
|
||||
|
||||
Restart Zed. The model dropdown lists every model from every
|
||||
configured endpoint with the `endpoint:model` selector
|
||||
(`helexa:Qwen/Qwen3.6-27B`, `openrouter:anthropic/claude-opus-4`,
|
||||
…). Switch mid-session; the next prompt routes to the new endpoint.
|
||||
|
||||
When only one endpoint is configured the prefix is dropped (model
|
||||
ids appear bare).
|
||||
|
||||
### Selector syntax
|
||||
|
||||
The `model` field on every internal request is parsed as
|
||||
`<endpoint>:<model>`:
|
||||
|
||||
- `openrouter:gpt-4o` → routes to the `openrouter` endpoint,
|
||||
model `gpt-4o`.
|
||||
- `helexa/large` → no colon → falls through to whichever endpoint
|
||||
is named in `default_endpoint`, model `helexa/large`.
|
||||
- `:gpt-5` → leading colon → also falls through to default.
|
||||
|
||||
## Endpoint cookbook
|
||||
|
||||
Copy-pasteable blocks. Mix and match.
|
||||
|
||||
### cortex / neuron (self-hosted)
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "helexa"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "Qwen/Qwen3.6-27B"
|
||||
max_tokens = 8192
|
||||
context_window = 32768
|
||||
```
|
||||
|
||||
Use `openai-responses` instead of `openai-chat` once cortex 0.1.16+
|
||||
is deployed and you want the Responses API surface (vision item
|
||||
shape, structured reasoning items, etc.).
|
||||
|
||||
### OpenAI directly
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "openai"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
wire_api = "openai-responses"
|
||||
api_key_env = "OPENAI_API_KEY"
|
||||
default_model = "gpt-5"
|
||||
```
|
||||
|
||||
`openai-responses` is the right choice for current OpenAI models;
|
||||
`openai-chat` works against legacy GPT-3.5/4 deployments and
|
||||
anything labelled "chat completions".
|
||||
|
||||
### Anthropic directly
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "anthropic"
|
||||
base_url = "https://api.anthropic.com/v1"
|
||||
wire_api = "anthropic-messages"
|
||||
api_key_env = "ANTHROPIC_API_KEY"
|
||||
default_model = "claude-opus-4"
|
||||
```
|
||||
|
||||
helexa-acp sends `x-api-key` + `anthropic-version: 2023-06-01`
|
||||
automatically. The `api_key_env` indirection keeps your key out of
|
||||
the config file.
|
||||
|
||||
### OpenRouter (multi-vendor proxy)
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "openrouter"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "OPENROUTER_API_KEY"
|
||||
default_model = "anthropic/claude-opus-4"
|
||||
```
|
||||
|
||||
OpenRouter speaks OpenAI-compat for every model it fronts, so
|
||||
`openai-chat` is the right wire format regardless of the
|
||||
underlying vendor.
|
||||
|
||||
### LM Studio (local)
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "lmstudio"
|
||||
base_url = "http://localhost:1234/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "auto"
|
||||
```
|
||||
|
||||
LM Studio's "auto" model id picks whatever's loaded. The same shape
|
||||
works for Ollama in compat mode (`http://localhost:11434/v1`) and
|
||||
vLLM.
|
||||
|
||||
### Multiple cortex deployments
|
||||
|
||||
```toml
|
||||
[[endpoints]]
|
||||
name = "lan"
|
||||
base_url = "http://hanzalova.internal:31313/v1"
|
||||
wire_api = "openai-chat"
|
||||
default_model = "Qwen/Qwen3.6-27B"
|
||||
|
||||
[[endpoints]]
|
||||
name = "cloud"
|
||||
base_url = "https://cortex.example.com/v1"
|
||||
wire_api = "openai-chat"
|
||||
api_key_env = "CLOUD_CORTEX_KEY"
|
||||
default_model = "Qwen/Qwen3-VL-8B"
|
||||
```
|
||||
|
||||
Use the `endpoint:model` selector to switch between them mid-session.
|
||||
|
||||
## Zed setup
|
||||
|
||||
`~/.config/zed/settings.json`:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"agent_servers": {
|
||||
"helexa": {
|
||||
"command": "helexa-acp"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Optional environment overrides for the binary:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"agent_servers": {
|
||||
"helexa": {
|
||||
"command": "helexa-acp",
|
||||
"env": {
|
||||
"HELEXA_ACP_LOG_FILE": "/tmp/helexa-acp.log",
|
||||
"RUST_LOG": "helexa_acp=debug"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`HELEXA_ACP_LOG_FILE` is the one you actually want — Zed doesn't
|
||||
surface the agent's stderr, so without that env var debug output is
|
||||
invisible. Point it at a file you can `tail -f`.
|
||||
|
||||
After restarting Zed: ⌘+? (or whatever your "Open Agent Panel"
|
||||
binding is) → select "helexa" → the model dropdown populates from
|
||||
your config → start prompting.
|
||||
|
||||
## Modes
|
||||
|
||||
Three session modes ship; the user picks via Zed's mode dropdown
|
||||
on the agent panel.
|
||||
|
||||
| Mode | Reads | Writes | Bash | Permission prompts |
|
||||
|------|-------|--------|------|--------------------|
|
||||
| **Default** | ✓ | with prompt | with prompt | per call |
|
||||
| **Bypass Permissions** | ✓ | ✓ | ✓ | never |
|
||||
| **Plan** | ✓ | only into plan dir | disabled | never (plan-dir writes auto-allow) |
|
||||
|
||||
### Default
|
||||
|
||||
Reads are always allowed (`read_file`, `list_dir` are
|
||||
unrestricted). Writes and shell commands prompt the user before
|
||||
running. The intended baseline for any session where the agent
|
||||
might do something you'd rather review first.
|
||||
|
||||
### Bypass Permissions
|
||||
|
||||
Auto-allow every tool call. Use for agentic loops you trust — bulk
|
||||
edits across many files, scripted workflows, prepared session
|
||||
templates. Never for code the agent hasn't seen before.
|
||||
|
||||
### Plan
|
||||
|
||||
The "draft an implementation plan before you write code" mode.
|
||||
Available tools:
|
||||
|
||||
- `read_file`, `list_dir`: unrestricted (read the codebase).
|
||||
- `write_file`, `edit_file`: allowed *only* under
|
||||
`$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. Any path
|
||||
outside that returns "plan mode: writes are restricted to …"
|
||||
back to the model so it self-corrects.
|
||||
- `bash`: disabled outright. Returns "plan mode: shell execution
|
||||
is disabled" if attempted.
|
||||
|
||||
When the plan is complete, the model presents a 3-option menu:
|
||||
|
||||
1. **Bypass Permissions** — implement the plan now, no prompts.
|
||||
2. **Default** — implement now with per-tool prompts.
|
||||
3. **Plan** (stay here) — refine the plan with more guidance.
|
||||
|
||||
Switch the mode dropdown to your preference and reply to proceed.
|
||||
|
||||
## Tools
|
||||
|
||||
Five tools, defined in `src/tools.rs`:
|
||||
|
||||
| Tool | Args | Gated in Default? |
|
||||
|------|------|-------------------|
|
||||
| `read_file` | `path`, `line?`, `limit?` | no |
|
||||
| `list_dir` | `path` | no |
|
||||
| `write_file` | `path`, `content` | yes |
|
||||
| `edit_file` | `path`, `old_text`, `new_text` | yes |
|
||||
| `bash` | `command`, `cwd?` | yes |
|
||||
|
||||
### Path handling
|
||||
|
||||
`~`, `~/`, `$HOME`, and `$HOME/` are expanded server-side before
|
||||
the path reaches ACP or local fs. Lets the model emit
|
||||
`~/git/repo/file.rs` and have it Just Work.
|
||||
|
||||
`read_file` first tries the editor's filesystem (ACP's
|
||||
`fs/read_text_file` — respects open buffers, workspace overlays,
|
||||
etc.). If that fails — typically because the path is outside Zed's
|
||||
workspace boundary — it falls back to `std::fs::read_to_string`.
|
||||
This lets the agent pull in shared material like
|
||||
`~/git/architecture/generic.md` from a different project's
|
||||
session.
|
||||
|
||||
The fallback is logged at warn level so you can see when it kicks
|
||||
in.
|
||||
|
||||
### Tool dispatch
|
||||
|
||||
Tool descriptions reach the model through a Qwen3 Hermes-format
|
||||
`# Tools` block injected into the system prompt — cortex/neuron
|
||||
pass the OpenAI `tools` request field through to the encoder
|
||||
unread, so we work the model into emitting `<tool_call>{json}</tool_call>`
|
||||
markers, which helexa-acp then parses out of the content stream. This applies to
|
||||
the helexa wire format; OpenAI / Anthropic endpoints with native
|
||||
tool support would use their own paths once they're wired in.
|
||||
|
||||
The parser is tolerant: malformed JSON (trailing braces, missing
|
||||
`name`, name nested in `arguments`) gets a repair pass; if that
|
||||
fails the call surfaces as a "Malformed tool call" card in Zed and
|
||||
the model gets a synthetic error result so it can self-correct.
|
||||
|
||||
## Session resume
|
||||
|
||||
helexa-acp persists every session to
|
||||
`$XDG_DATA_HOME/helexa-acp/sessions/<id>.json`. Zed's `session/list`
|
||||
RPC asks helexa-acp to enumerate them on workspace open;
|
||||
`session/load` rehydrates and replays the transcript as
|
||||
`session/update` notifications so the agent panel renders the
|
||||
prior conversation.
|
||||
|
||||
Behaviour:
|
||||
|
||||
- Persisted per-round, so a mid-turn agent stall (long bash, wedged
|
||||
ACP roundtrip) doesn't lose earlier rounds.
|
||||
- Survives editor restart and the helexa-acp binary upgrading
|
||||
between versions.
|
||||
- Project-scoped: only sessions whose `cwd` matches the workspace
|
||||
are listed.
|
||||
|
||||
To wipe history: `rm -rf $XDG_DATA_HOME/helexa-acp/sessions/`.
|
||||
|
||||
## Context compaction
|
||||
|
||||
When an endpoint sets `context_window`, helexa-acp projects the
|
||||
rolling history into a token budget before each request — old
|
||||
`ToolResult` content (read_file payloads are the worst offenders)
|
||||
gets elided to one-line markers, preserving `tool_call_id` pairing
|
||||
so the wire schema stays valid.
|
||||
|
||||
System prompts, user turns, and the most recent ~4 messages are
|
||||
never elided. The full history stays on disk; compaction is a
|
||||
per-request projection, not a destructive edit.
|
||||
|
||||
Set `context_window = 32768` for a 32 K Qwen3, `131072` for a
|
||||
modern Claude, etc. With `max_tokens` also set, the budget is
|
||||
`context_window - max_tokens - 512` (the 512 tokens are a safety margin).
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "default endpoint 'helexa' has no usable provider — check config"
|
||||
|
||||
The named default endpoint failed to construct. Usually:
|
||||
|
||||
- `api_key_env` references a variable that isn't set in the env
|
||||
Zed launched helexa-acp with.
|
||||
- The TOML's `wire_api` is misspelled (only `openai-chat`,
|
||||
`openai-responses`, `anthropic-messages` are accepted).
|
||||
|
||||
Test by running `helexa-acp` directly from a shell — startup
|
||||
errors land on stderr.
|
||||
|
||||
### Model dropdown is empty
|
||||
|
||||
Each provider's `list_models` failed at startup. Look at
|
||||
`HELEXA_ACP_LOG_FILE` for "list_models failed; this endpoint's
|
||||
models won't appear in the picker". Likely the endpoint URL is
|
||||
wrong, the API key is invalid, or the upstream `/v1/models`
|
||||
endpoint isn't responding.
|
||||
|
||||
The agent still works against `default_model` even when the
|
||||
dropdown is empty — list-models is for picking, not routing.
|
||||
|
||||
### "prompt_too_long" / agent stalls mid-conversation
|
||||
|
||||
You hit the model's context window. Set `context_window` on the
|
||||
endpoint and helexa-acp will compact before sending. The log line
|
||||
`context compaction applied` confirms it's running; if it fires
|
||||
but the upstream still rejects, the compaction heuristic
|
||||
under-counted and the budget needs tuning down.
|
||||
|
||||
### Reading files outside the workspace returns "not found"
|
||||
|
||||
Zed's `fs/read_text_file` is workspace-scoped. helexa-acp falls
|
||||
back to local `std::fs` automatically when that fails — look for
|
||||
`fs/read_text_file failed; falling back to local std::fs` in the
|
||||
log. If even local read fails, the file genuinely doesn't exist
|
||||
or the user process lacks permissions.
|
||||
|
||||
### Tool calls render as text instead of structured cards
|
||||
|
||||
The model is emitting `<tool_call>` markers that the parser can't
|
||||
decode. Two common causes:
|
||||
|
||||
1. The system prompt isn't reaching the model (cortex/neuron's
|
||||
tool-block injection didn't fire). Confirm with
|
||||
`RUST_LOG=helexa_acp=debug` and look at the outgoing
|
||||
`POST /chat/completions` body.
|
||||
2. The model itself is too small / undertrained to follow the
|
||||
Hermes format reliably. helexa-acp has shape-based name
|
||||
inference and JSON repair, but there's a floor below which
|
||||
nothing helps.
|
||||
|
||||
### Plan-mode writes refused even inside the plan dir
|
||||
|
||||
The path comparison is byte-for-byte. If the model emits a path
|
||||
with `~` and the plan_dir has the expanded form, expansion runs
|
||||
*before* the comparison — but resolved-vs-symlinked-path
|
||||
mismatches can still bite. The error message names the attempted
|
||||
path and the expected prefix so you can compare directly.
|
||||
|
||||
## Architecture
|
||||
|
||||
Source layout under `crates/helexa-acp/src/`:
|
||||
|
||||
| File | Responsibility |
|
||||
|------|----------------|
|
||||
| `main.rs` | tokio + Stdio transport. Builds providers, hands off to `agent::Agent` |
|
||||
| `config.rs` | TOML + env-fallback config, endpoint resolver |
|
||||
| `agent.rs` | ACP handlers (initialize, session/new, session/prompt, session/cancel, session/set_mode, session/set_model, session/load, session/list), prompt loop with tool-call recursion |
|
||||
| `session.rs` | Per-session state map (Arc<RwLock<HashMap<…>>>) |
|
||||
| `store.rs` | On-disk session persistence, plan-dir resolution |
|
||||
| `prompt.rs` | System-prompt assembly, plan-mode addendum |
|
||||
| `tools.rs` | Tool schemas + shape-based name inference |
|
||||
| `tool_runner.rs` | Dispatch a single tool call through ACP client RPCs; permission gate |
|
||||
| `qwen3.rs` | Qwen3 Hermes tool-format parser (`<tool_call>` / `<think>` markers) |
|
||||
| `compaction.rs` | Token-budget compaction for the rolling history |
|
||||
| `path_util.rs` | `~` / `$HOME` expansion shared across every path-taking tool |
|
||||
| `provider/openai_chat.rs` | OpenAI chat completions provider |
|
||||
| `provider/openai_responses.rs` | OpenAI Responses API provider |
|
||||
| `provider/anthropic_messages.rs` | Anthropic Messages API provider |
|
||||
|
||||
### Adding a new wire format
|
||||
|
||||
1. New file under `src/provider/` implementing the `Provider`
|
||||
trait (encoder + SSE decoder).
|
||||
2. Add a `WireApi` variant in `config.rs`.
|
||||
3. Wire it into `build_provider` in `main.rs`.
|
||||
4. Done — every other module is wire-format-agnostic.
|
||||
|
||||
### Concurrency
|
||||
|
||||
- `Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>` —
|
||||
per-session mutex so concurrent requests across sessions don't
|
||||
contend; the map's RwLock is read-mostly.
|
||||
- Every tool call is dispatched serially within a session (parallel
|
||||
dispatch would require Zed to handle interleaved permission
|
||||
prompts).
|
||||
- Provider streams are back-pressured by the consumer (bounded
|
||||
mpsc channels).
|
||||
|
||||
### Self-contained
|
||||
|
||||
The crate has no workspace-internal dependencies (no
|
||||
`cortex-core`, no `cortex-gateway`). Migration to a dedicated
|
||||
GitHub repo for cross-platform CI / cargo-dist binaries is
|
||||
Cargo.toml-only.
|
||||
|
||||
## Status
|
||||
|
||||
- Stages 1–6 shipped: scaffold, agent loop, tools, modes, session
|
||||
resume, image input, model picker, three wire formats.
|
||||
- Stage 8 (RPM + multi-platform CI) tracked in the canonical plan;
|
||||
Linux x86_64 RPM ships today via the cortex monorepo's Gitea
|
||||
Actions.
|
||||
|
||||
## Contributing
|
||||
|
||||
Repository: https://git.lair.cafe/helexa/helexa (`crates/helexa-acp/`).
|
||||
Issues / PRs welcome. The canonical staged plan is in
|
||||
`~/.claude/plans/plan-the-per-device-worker-abstract-micali.md` on
|
||||
the maintainer's machine; the substages 3a–3e and 6a/6b that the
|
||||
canonical plan didn't anticipate are documented in commit messages.
|
||||
|
||||
CI: `cargo fmt --check --all`, `cargo clippy --workspace -- -D
|
||||
warnings`, `cargo test --workspace` must all pass before merge.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,425 +0,0 @@
|
||||
//! Rolling-conversation compaction for small-context local models.
|
||||
//!
|
||||
//! The tool-call loop in [`crate::agent`] grows the message vec it
|
||||
//! sends upstream every round. On a frontier model that's fine; on a
|
||||
//! 32 K Qwen3 the first few `read_file` results can push the prompt
|
||||
//! past the model's context window, at which point cortex/neuron
|
||||
//! refuses with `prompt_too_long` and the whole turn dies. Long-form
|
||||
//! local agents are unusable without something here.
|
||||
//!
|
||||
//! Strategy (intentionally simple — no LLM-summarization round-trip,
|
||||
//! no tokenizer dependency):
|
||||
//!
|
||||
//! 1. **Protect** the things the model cannot reason without:
|
||||
//! - The system prompt (idx 0).
|
||||
//! - Every `Role::User` turn (the user's intent — irreplaceable).
|
||||
//! - The last [`KEEP_TAIL`] messages (most recent rounds stay
|
||||
//! verbatim so the model can keep working on what it just
|
||||
//! observed).
|
||||
//! 2. **Elide** older `Role::Assistant` prose and older `Role::Tool`
|
||||
//! result content. The structure stays — `tool_call_id`s, tool
|
||||
//! names, and argument JSON survive intact — so OpenAI's strict
|
||||
//! `tool_calls` ↔ `tool` pairing schema remains satisfied. Only
|
||||
//! the *payload* shrinks to a one-line marker.
|
||||
//! 3. Walk oldest→newest, recomputing the budget after each elision.
|
||||
//! Stop as soon as we fit; we don't compact more than necessary.
|
||||
//! 4. If we still exceed budget after eliding everything we're
|
||||
//! allowed to, return what we have. The upstream will surface a
|
||||
//! `prompt_too_long` error and the user can intervene; that's
|
||||
//! better than silently dropping content the model needs.
|
||||
//!
|
||||
//! Token estimation uses a `chars / 3.5` heuristic — conservative
|
||||
//! (over-estimates tokens slightly) so we compact a touch early
|
||||
//! rather than a touch late.
|
||||
|
||||
use crate::provider::{Message, MessageContent, MessagePart, Role};
|
||||
|
||||
/// Number of trailing messages that compaction never touches. Sized to
/// cover roughly one in-flight tool round: the assistant turn that
/// issued the calls, each tool result, and a little slack.
const KEEP_TAIL: usize = 4;

/// Payloads shorter than this many bytes are left alone: swapping them
/// for a marker saves too little to justify losing the detail. Works
/// out to roughly 60–80 estimated tokens.
const ELIDE_MIN_CHARS: usize = 256;

/// Approximate number of characters per token for mixed English prose
/// and code. Real tokenizers vary (GPT-4o ≈ 4 chars/token on English
/// prose, ≈ 3 chars/token on code-heavy text); the value sits toward the
/// conservative end so the budget check trips *before* the upstream
/// tokenizer rejects the prompt.
const CHARS_PER_TOKEN: f32 = 3.5;

/// Flat per-message token charge for the envelope (role + JSON
/// framing). Small per message, but it accumulates over long histories.
const ENVELOPE_TOKENS: usize = 8;

/// Nominal token cost charged per image by the budget estimator. Real
/// vision tokenizers vary widely (256–1024 tokens at typical
/// resolutions on Qwen3-VL; OpenAI's `low`/`high` detail settings range
/// from ~85 to ~1000+). 512 is a middle value that stops compaction
/// from treating images as free.
const IMAGE_TOKENS_APPROX: usize = 512;
|
||||
|
||||
/// Summary of one [`compact_to_budget`] run, meant for the caller's
/// logs. All figures come from [`estimate_tokens`], so they are
/// approximations — do not treat them as interchangeable with token
/// counts reported by the upstream.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CompactionStats {
    /// Estimated token total of the messages as they were passed in.
    pub original_tokens: usize,
    /// Estimated token total of the returned messages. Identical to
    /// `original_tokens` when nothing had to be compacted.
    pub final_tokens: usize,
    /// How many messages had their payload replaced by a marker. Zero
    /// on the common path where the input already fits.
    pub elided_messages: usize,
}

impl CompactionStats {
    /// Stats for the no-op case: both totals equal `tokens` and no
    /// message was elided.
    fn unchanged(tokens: usize) -> Self {
        Self {
            original_tokens: tokens,
            final_tokens: tokens,
            ..Self::default()
        }
    }
}
|
||||
|
||||
/// Approximate token count for one message. Sums the textual
|
||||
/// payload's chars, divides by [`CHARS_PER_TOKEN`], and adds an
|
||||
/// envelope constant. Cheap (no allocation) so safe to call once per
|
||||
/// message per round.
|
||||
pub fn estimate_tokens(msg: &Message) -> usize {
|
||||
let chars = match &msg.content {
|
||||
MessageContent::Text { text } => text.len(),
|
||||
MessageContent::MultiPart { parts } => parts
|
||||
.iter()
|
||||
.map(|p| match p {
|
||||
MessagePart::Text { text } => text.len(),
|
||||
// Each image is one block in the context window; the
|
||||
// upstream tokenizer handles the real cost (and it
|
||||
// varies wildly by model — Qwen3-VL uses ~256-1024
|
||||
// tokens per image depending on size). Take a
|
||||
// middle estimate so the budget tracker doesn't
|
||||
// pretend images are free.
|
||||
MessagePart::Image(_) => IMAGE_TOKENS_APPROX * CHARS_PER_TOKEN as usize,
|
||||
})
|
||||
.sum(),
|
||||
MessageContent::ToolCalls { text, calls } => {
|
||||
let txt = text.as_deref().map(|s| s.len()).unwrap_or(0);
|
||||
let calls_size: usize = calls
|
||||
.iter()
|
||||
.map(|c| c.name.len() + c.arguments.len() + c.id.len())
|
||||
.sum();
|
||||
txt + calls_size
|
||||
}
|
||||
MessageContent::ToolResult {
|
||||
tool_call_id,
|
||||
content,
|
||||
} => tool_call_id.len() + content.len(),
|
||||
};
|
||||
((chars as f32 / CHARS_PER_TOKEN) as usize) + ENVELOPE_TOKENS
|
||||
}
|
||||
|
||||
/// Sum of [`estimate_tokens`] across all messages.
|
||||
pub fn total_tokens(messages: &[Message]) -> usize {
|
||||
messages.iter().map(estimate_tokens).sum()
|
||||
}
|
||||
|
||||
/// Project `messages` into a vec whose estimated token count fits in
|
||||
/// `budget` tokens. Returns the projection plus stats about what
|
||||
/// was done. When the input already fits, the projection is a clone
|
||||
/// of the input and stats report zero elisions.
|
||||
///
|
||||
/// See module docs for the strategy and protected set.
|
||||
pub fn compact_to_budget(messages: &[Message], budget: usize) -> (Vec<Message>, CompactionStats) {
|
||||
let original = total_tokens(messages);
|
||||
if original <= budget {
|
||||
return (messages.to_vec(), CompactionStats::unchanged(original));
|
||||
}
|
||||
|
||||
let mut out = messages.to_vec();
|
||||
let len = out.len();
|
||||
let tail_start = len.saturating_sub(KEEP_TAIL);
|
||||
let mut elided = 0usize;
|
||||
|
||||
// Two passes. First pass: ToolResult contents (largest savings
|
||||
// per elision — read_file payloads land here). Second pass: long
|
||||
// Assistant prose. We don't interleave because eliding a long
|
||||
// assistant turn before a really old read_file would do less
|
||||
// good per elision; oldest-first ordering is enforced *within*
|
||||
// each pass instead.
|
||||
for pass in 0..2 {
|
||||
for i in 1..tail_start {
|
||||
if matches!(out[i].role, Role::User) {
|
||||
continue;
|
||||
}
|
||||
let target_pass_2 = matches!(
|
||||
&out[i].content,
|
||||
MessageContent::Text { .. } | MessageContent::ToolCalls { .. }
|
||||
);
|
||||
let target_pass_1 = matches!(&out[i].content, MessageContent::ToolResult { .. });
|
||||
let in_pass = (pass == 0 && target_pass_1) || (pass == 1 && target_pass_2);
|
||||
if !in_pass {
|
||||
continue;
|
||||
}
|
||||
if elide_in_place(&mut out[i]) {
|
||||
elided += 1;
|
||||
if total_tokens(&out) <= budget {
|
||||
let final_tokens = total_tokens(&out);
|
||||
return (
|
||||
out,
|
||||
CompactionStats {
|
||||
original_tokens: original,
|
||||
final_tokens,
|
||||
elided_messages: elided,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let final_tokens = total_tokens(&out);
|
||||
(
|
||||
out,
|
||||
CompactionStats {
|
||||
original_tokens: original,
|
||||
final_tokens,
|
||||
elided_messages: elided,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Shrink one message's payload while keeping its structural role
|
||||
/// (so tool_call_id pairing survives). Returns `true` when the
|
||||
/// message changed.
|
||||
///
|
||||
/// - `ToolResult.content` → `(elided: N bytes of tool result)`
|
||||
/// - `ToolCalls.text` → `(elided: N bytes of assistant prose)`
|
||||
/// - `Text` (assistant) → `(elided: N bytes of assistant prose)`
|
||||
///
|
||||
/// Already-tiny payloads are skipped — eliding a 50-byte string
|
||||
/// would *grow* it once the marker is in place.
|
||||
fn elide_in_place(msg: &mut Message) -> bool {
|
||||
match &mut msg.content {
|
||||
MessageContent::ToolResult { content, .. } => {
|
||||
if content.len() < ELIDE_MIN_CHARS {
|
||||
return false;
|
||||
}
|
||||
*content = format!("(elided: {} bytes of tool result)", content.len());
|
||||
true
|
||||
}
|
||||
MessageContent::ToolCalls { text, .. } => match text {
|
||||
Some(t) if t.len() >= ELIDE_MIN_CHARS => {
|
||||
*text = Some(format!("(elided: {} bytes of assistant prose)", t.len()));
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
},
|
||||
MessageContent::Text { text } => {
|
||||
if text.len() < ELIDE_MIN_CHARS {
|
||||
return false;
|
||||
}
|
||||
*text = format!("(elided: {} bytes of assistant prose)", text.len());
|
||||
true
|
||||
}
|
||||
MessageContent::MultiPart { .. } => {
|
||||
// MultiPart messages today only exist as User turns,
|
||||
// and User turns are protected by the role check in
|
||||
// `compact_to_budget` — so this branch is unreachable
|
||||
// for current call sites. Returning false keeps the
|
||||
// unreachable path benign if a future stage starts
|
||||
// emitting MultiPart on other roles.
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::ToolCall;

    // ---- fixtures -------------------------------------------------

    /// Plain-text message carrying `text` under the given role.
    fn text_msg(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text { text: text.into() },
        }
    }

    fn sys(text: &str) -> Message {
        text_msg(Role::System, text)
    }

    fn user(text: &str) -> Message {
        text_msg(Role::User, text)
    }

    fn assistant_text(text: &str) -> Message {
        text_msg(Role::Assistant, text)
    }

    /// Assistant turn issuing exactly one tool call, with optional prose.
    fn assistant_calls(text: Option<&str>, name: &str, args: &str, id: &str) -> Message {
        let call = ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: args.into(),
        };
        Message {
            role: Role::Assistant,
            content: MessageContent::ToolCalls {
                text: text.map(str::to_owned),
                calls: vec![call],
            },
        }
    }

    /// Tool-role message answering call `id` with `body`.
    fn tool_result(id: &str, body: &str) -> Message {
        let content = MessageContent::ToolResult {
            tool_call_id: id.into(),
            content: body.into(),
        };
        Message {
            role: Role::Tool,
            content,
        }
    }

    // ---- extraction helpers --------------------------------------

    /// Text payload of `msg`; panics if the message is not `Text`.
    fn text_of(msg: &Message) -> &str {
        match &msg.content {
            MessageContent::Text { text } => text.as_str(),
            other => panic!("expected Text, got {other:?}"),
        }
    }

    /// Result body of `msg`; panics if the message is not `ToolResult`.
    fn tool_result_body(msg: &Message) -> &str {
        match &msg.content {
            MessageContent::ToolResult { content, .. } => content.as_str(),
            other => panic!("expected ToolResult, got {other:?}"),
        }
    }

    // ---- tests ---------------------------------------------------

    #[test]
    fn under_budget_is_a_no_op_clone() {
        let history = vec![sys("you are an agent"), user("hi"), assistant_text("hello")];
        let (projected, stats) = compact_to_budget(&history, 10_000);

        assert_eq!(stats.elided_messages, 0);
        assert_eq!(stats.original_tokens, stats.final_tokens);
        assert_eq!(projected.len(), history.len());
        // Payload survives byte-for-byte.
        assert_eq!(text_of(&projected[2]), "hello");
    }

    #[test]
    fn elides_old_tool_result_before_old_assistant_prose() {
        // Layout: sys, user, calls(c0), BIG tool result, BIG assistant
        // prose, user, calls(c1), small tool result. With KEEP_TAIL = 4
        // the last four are off-limits, so the big tool result is the
        // one fat payload in the prunable range and is handled by
        // pass 0 (tool results), ahead of pass 1 (prose).
        let big_result = "X".repeat(4096);
        let big_prose = "Y".repeat(2048);
        let history = vec![
            sys("preamble"),
            user("first ask"),
            assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "c0"),
            tool_result("c0", &big_result),
            assistant_text(&big_prose),
            user("follow up"),
            assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "c1"),
            tool_result("c1", "short result body"),
        ];
        // Half the current estimate is low enough to force compaction.
        let budget = total_tokens(&history) / 2;
        let (projected, stats) = compact_to_budget(&history, budget);

        assert!(
            stats.elided_messages >= 1,
            "expected at least one elision, got {stats:?}"
        );

        // The oldest fat payload must have been replaced by a marker.
        let content = tool_result_body(&projected[3]);
        assert!(
            content.starts_with("(elided:"),
            "tool result not elided: {content:?}"
        );

        // The final (tail) message is still verbatim.
        assert!(matches!(
            &projected[projected.len() - 1].content,
            MessageContent::ToolResult { content, .. } if content == "short result body"
        ));
    }

    #[test]
    fn never_elides_system_or_user_turns() {
        let big_user = "U".repeat(8192);
        let history = vec![sys("preamble"), user(&big_user), assistant_text("ok")];
        // A budget this small forces every permitted elision.
        let budget = 10;
        let (projected, _stats) = compact_to_budget(&history, budget);

        assert_eq!(text_of(&projected[0]), "preamble");
        // The user turn keeps its full size despite being huge.
        assert_eq!(text_of(&projected[1]).len(), big_user.len());
    }

    #[test]
    fn preserves_tool_call_id_pairing_after_elision() {
        // OpenAI strict mode rejects a tool result whose tool_call_id
        // has no matching assistant tool_call before it, so elision
        // has to leave that linkage alone.
        let big = "Z".repeat(4096);
        let history = vec![
            sys("preamble"),
            user("first"),
            assistant_calls(None, "read_file", r#"{"path":"/a"}"#, "call_42"),
            tool_result("call_42", &big),
            // Protected tail from here on.
            user("next"),
            assistant_calls(None, "read_file", r#"{"path":"/b"}"#, "call_43"),
            tool_result("call_43", "ok"),
            assistant_text("done"),
        ];
        let budget = total_tokens(&history) / 3;
        let (projected, _stats) = compact_to_budget(&history, budget);

        // Both halves of the pair must still carry call_42.
        let call_id = match &projected[2].content {
            MessageContent::ToolCalls { calls, .. } => calls[0].id.clone(),
            other => panic!("expected ToolCalls, got {other:?}"),
        };
        match &projected[3].content {
            MessageContent::ToolResult { tool_call_id, .. } => {
                assert_eq!(tool_call_id, &call_id, "pairing broken");
            }
            other => panic!("expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn estimate_tokens_grows_with_content() {
        let small = estimate_tokens(&sys("hi"));
        let large = estimate_tokens(&sys(&"x".repeat(10_000)));
        assert!(large > small * 100);
    }

    #[test]
    fn elide_in_place_skips_short_content() {
        let mut msg = tool_result("c0", "tiny");
        assert!(!elide_in_place(&mut msg));
        assert_eq!(tool_result_body(&msg), "tiny");
    }

    #[test]
    fn returns_best_effort_when_budget_unmeetable() {
        // One enormous user message that may not be elided, against a
        // budget of 10. No error is expected: the projection comes
        // back as-is and the upstream gets to reject the prompt.
        let big_user = "U".repeat(100_000);
        let history = vec![sys("preamble"), user(&big_user)];
        let (projected, stats) = compact_to_budget(&history, 10);

        assert_eq!(projected.len(), history.len());
        assert!(stats.final_tokens > 10, "still over budget by design");
    }
}
|
||||
@@ -1,424 +0,0 @@
|
||||
//! Configuration for the helexa-acp bridge.
|
||||
//!
|
||||
//! Loaded from `$XDG_CONFIG_HOME/helexa-acp/config.toml` (or
|
||||
//! `~/.config/helexa-acp/config.toml` as a fallback). If no config file
|
||||
//! exists, falls back to building a single anonymous endpoint from env
|
||||
//! vars — that keeps "just point at one cortex" frictionless without
|
||||
//! requiring a config file on disk.
|
||||
//!
|
||||
//! The design goal is "the missing ACP binary for users with multiple
|
||||
//! API endpoints (possibly on a private LAN, possibly mixing wire
|
||||
//! types)". Hence: every endpoint is named, has its own wire API, and
|
||||
//! has its own default model. The agent's selected model id can be
|
||||
//! prefixed `endpoint:model` to route across endpoints; a bare
|
||||
//! `model` falls through to the configured `default_endpoint`.
|
||||
//!
|
||||
//! ### Example TOML
|
||||
//!
|
||||
//! ```toml
|
||||
//! default_endpoint = "helexa"
|
||||
//!
|
||||
//! [[endpoints]]
|
||||
//! name = "helexa"
|
||||
//! base_url = "http://hanzalova.internal:31313/v1"
|
||||
//! wire_api = "openai-chat"
|
||||
//! default_model = "helexa/large"
|
||||
//!
|
||||
//! [[endpoints]]
|
||||
//! name = "openrouter"
|
||||
//! base_url = "https://openrouter.ai/api/v1"
|
||||
//! wire_api = "openai-chat"
|
||||
//! api_key_env = "OPENROUTER_API_KEY"
|
||||
//! default_model = "anthropic/claude-opus-4"
|
||||
//!
|
||||
//! [[endpoints]]
|
||||
//! name = "lmstudio"
|
||||
//! base_url = "http://localhost:1234/v1"
|
||||
//! wire_api = "openai-chat"
|
||||
//! default_model = "auto"
|
||||
//! ```
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use url::Url;
|
||||
|
||||
/// Base URL used by the env-var fallback when `HELEXA_ACP_BASE_URL`
/// is not set.
const DEFAULT_BASE_URL: &str = "http://hanzalova.internal:31313/v1";

/// Model id used by the env-var fallback when `HELEXA_ACP_MODEL` is
/// not set.
const DEFAULT_MODEL: &str = "helexa/large";

/// Name of the single endpoint synthesised by the env-var fallback.
const DEFAULT_ENDPOINT_NAME: &str = "default";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
/// Name of the endpoint used when a request doesn't pick one
|
||||
/// explicitly. Must reference an entry in `endpoints`. Defaults to
|
||||
/// the first endpoint declared if unset.
|
||||
#[serde(default)]
|
||||
pub default_endpoint: Option<String>,
|
||||
/// Per-endpoint configuration. At least one entry is required.
|
||||
#[serde(default)]
|
||||
pub endpoints: Vec<EndpointConfig>,
|
||||
/// Optional path to a system-prompt file. When unset, the built-in
|
||||
/// default prompt from `prompt.rs` is used.
|
||||
#[serde(default)]
|
||||
pub system_prompt_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EndpointConfig {
|
||||
/// Short identifier used in `endpoint:model` routing and in logs.
|
||||
pub name: String,
|
||||
/// Base URL of the OpenAI-compatible API. Must include the `/v1`
|
||||
/// (or equivalent) suffix — paths like `chat/completions` and
|
||||
/// `models` are joined onto this.
|
||||
pub base_url: Url,
|
||||
/// Wire protocol the endpoint speaks. Phase 1 supports
|
||||
/// [`WireApi::OpenAiChat`] only; `openai-responses` and
|
||||
/// `anthropic-messages` land later behind their own providers.
|
||||
#[serde(default)]
|
||||
pub wire_api: WireApi,
|
||||
/// Model to use when the client hasn't picked one via
|
||||
/// `session/set_model`.
|
||||
#[serde(default)]
|
||||
pub default_model: Option<String>,
|
||||
/// Static API key to send as `Authorization: Bearer …`. Prefer
|
||||
/// `api_key_env` for anything sensitive — keys in plain TOML are a
|
||||
/// liability.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Env var name to read for the API key. Resolved at startup so a
|
||||
/// missing env var yields a clear error rather than silent
|
||||
/// unauthenticated calls.
|
||||
#[serde(default)]
|
||||
pub api_key_env: Option<String>,
|
||||
/// Cap on the model's output tokens per turn. `None` lets the
|
||||
/// upstream pick its own default (cortex/neuron's default is
|
||||
/// often small enough to trip Zed's "Output Limit Reached" on
|
||||
/// long responses). Set to e.g. `32768` to let the model
|
||||
/// produce longer turns. Goes into the OpenAI `max_tokens`
|
||||
/// request field.
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u64>,
|
||||
/// Model context window in tokens (prompt + response). When set,
|
||||
/// the agent compacts conversation history before each completion
|
||||
/// so the prompt fits within `context_window - max_tokens - safety`
|
||||
/// tokens — long sessions on small-context local models (Qwen3 at
|
||||
/// 32 K) survive past the first few tool-call rounds rather than
|
||||
/// dying with `prompt_too_long`. `None` disables compaction.
|
||||
#[serde(default)]
|
||||
pub context_window: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub enum WireApi {
|
||||
/// `POST {base}/chat/completions` returning OpenAI-format SSE.
|
||||
/// Compatible with cortex, LM Studio, Ollama (compat mode),
|
||||
/// OpenRouter, OpenAI itself.
|
||||
#[default]
|
||||
#[serde(rename = "openai-chat")]
|
||||
OpenAiChat,
|
||||
/// `POST {base}/responses` — OpenAI's newer Responses API. Not
|
||||
/// implemented yet; the variant is reserved so endpoint configs
|
||||
/// can be authored ahead of provider support.
|
||||
#[serde(rename = "openai-responses")]
|
||||
OpenAiResponses,
|
||||
/// `POST {base}/messages` — Anthropic format. Reserved.
|
||||
#[serde(rename = "anthropic-messages")]
|
||||
AnthropicMessages,
|
||||
}
|
||||
|
||||
impl EndpointConfig {
|
||||
/// Resolve the API key from `api_key` (literal) or `api_key_env`
|
||||
/// (env-var lookup). Returns `Ok(None)` when neither is set;
|
||||
/// `Err` when `api_key_env` references a missing variable.
|
||||
pub fn resolve_api_key(&self) -> anyhow::Result<Option<String>> {
|
||||
if let Some(literal) = &self.api_key {
|
||||
return Ok(Some(literal.clone()));
|
||||
}
|
||||
if let Some(var) = &self.api_key_env {
|
||||
return Ok(Some(std::env::var(var).with_context(|| {
|
||||
format!(
|
||||
"endpoint '{}' references missing env var {}",
|
||||
self.name, var
|
||||
)
|
||||
})?));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// `{base_url}/chat/completions`.
|
||||
pub fn chat_completions_url(&self) -> Url {
|
||||
join_segments(&self.base_url, &["chat", "completions"])
|
||||
}
|
||||
|
||||
/// `{base_url}/responses` — OpenAI Responses API endpoint.
|
||||
pub fn responses_url(&self) -> Url {
|
||||
join_segments(&self.base_url, &["responses"])
|
||||
}
|
||||
|
||||
/// `{base_url}/models`. Called from `Provider::list_models`, which
|
||||
/// Stage 4 wires into the model-picker dropdown; until then it's
|
||||
/// reachable code with no in-tree callers.
|
||||
#[allow(dead_code)]
|
||||
pub fn models_url(&self) -> Url {
|
||||
join_segments(&self.base_url, &["models"])
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load from TOML at the standard config path, or build from env
|
||||
/// vars if no file exists. Env-fallback yields a single endpoint
|
||||
/// named `"default"`.
|
||||
pub fn load() -> anyhow::Result<Self> {
|
||||
let path = config_path();
|
||||
if let Some(path) = &path
|
||||
&& path.exists()
|
||||
{
|
||||
return Self::from_file(path);
|
||||
}
|
||||
Self::from_env()
|
||||
}
|
||||
|
||||
/// Single-endpoint config constructed from `HELEXA_ACP_BASE_URL`,
|
||||
/// `HELEXA_ACP_MODEL`, `HELEXA_ACP_API_KEY`,
|
||||
/// `HELEXA_ACP_SYSTEM_PROMPT_PATH`, `HELEXA_ACP_MAX_TOKENS`.
|
||||
pub fn from_env() -> anyhow::Result<Self> {
|
||||
let base_url = std::env::var("HELEXA_ACP_BASE_URL")
|
||||
.ok()
|
||||
.unwrap_or_else(|| DEFAULT_BASE_URL.into());
|
||||
let base_url = Url::parse(&base_url)
|
||||
.with_context(|| format!("HELEXA_ACP_BASE_URL is not a valid URL ({base_url})"))?;
|
||||
let default_model = std::env::var("HELEXA_ACP_MODEL")
|
||||
.ok()
|
||||
.unwrap_or_else(|| DEFAULT_MODEL.into());
|
||||
let api_key = std::env::var("HELEXA_ACP_API_KEY")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty());
|
||||
let system_prompt_path = std::env::var("HELEXA_ACP_SYSTEM_PROMPT_PATH")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(PathBuf::from);
|
||||
let max_tokens = std::env::var("HELEXA_ACP_MAX_TOKENS")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| {
|
||||
s.parse::<u64>().with_context(|| {
|
||||
format!("HELEXA_ACP_MAX_TOKENS is not a positive integer ({s})")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
let context_window = std::env::var("HELEXA_ACP_CONTEXT_WINDOW")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| {
|
||||
s.parse::<usize>().with_context(|| {
|
||||
format!("HELEXA_ACP_CONTEXT_WINDOW is not a positive integer ({s})")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(Self {
|
||||
default_endpoint: Some(DEFAULT_ENDPOINT_NAME.into()),
|
||||
endpoints: vec![EndpointConfig {
|
||||
name: DEFAULT_ENDPOINT_NAME.into(),
|
||||
base_url,
|
||||
wire_api: WireApi::OpenAiChat,
|
||||
default_model: Some(default_model),
|
||||
api_key,
|
||||
api_key_env: None,
|
||||
max_tokens,
|
||||
context_window,
|
||||
}],
|
||||
system_prompt_path,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
|
||||
let text = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("read config {}", path.display()))?;
|
||||
let mut cfg: Self =
|
||||
toml::from_str(&text).with_context(|| format!("parse config {}", path.display()))?;
|
||||
cfg.validate()?;
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
fn validate(&mut self) -> anyhow::Result<()> {
|
||||
if self.endpoints.is_empty() {
|
||||
return Err(anyhow!("config has no [[endpoints]] entries"));
|
||||
}
|
||||
for (i, ep) in self.endpoints.iter().enumerate() {
|
||||
if ep.name.is_empty() {
|
||||
return Err(anyhow!("endpoints[{i}] has empty name"));
|
||||
}
|
||||
if ep.name.contains(':') {
|
||||
return Err(anyhow!(
|
||||
"endpoints[{i}].name '{}' contains ':' which would clash \
|
||||
with the endpoint:model selector syntax",
|
||||
ep.name
|
||||
));
|
||||
}
|
||||
}
|
||||
// Pick a default endpoint if none was named.
|
||||
if self.default_endpoint.is_none() {
|
||||
self.default_endpoint = Some(self.endpoints[0].name.clone());
|
||||
}
|
||||
let default_name = self.default_endpoint.as_deref().unwrap();
|
||||
if !self.endpoints.iter().any(|e| e.name == default_name) {
|
||||
return Err(anyhow!(
|
||||
"default_endpoint '{default_name}' is not declared in [[endpoints]]"
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Look up an endpoint by name. Returns `None` if not configured.
|
||||
pub fn endpoint(&self, name: &str) -> Option<&EndpointConfig> {
|
||||
self.endpoints.iter().find(|e| e.name == name)
|
||||
}
|
||||
|
||||
/// The default endpoint (guaranteed to exist after `validate`).
|
||||
pub fn default_endpoint(&self) -> &EndpointConfig {
|
||||
let name = self
|
||||
.default_endpoint
|
||||
.as_deref()
|
||||
.expect("default_endpoint set by validate");
|
||||
self.endpoint(name)
|
||||
.expect("default_endpoint resolves after validate")
|
||||
}
|
||||
}
|
||||
|
||||
/// Split an ACP-side `model` string into an optional endpoint name and
/// the raw model id.
///
/// | Input                 | Result                             |
/// |-----------------------|------------------------------------|
/// | `helexa:helexa/large` | (`Some("helexa")`, `helexa/large`) |
/// | `helexa/large`        | (`None`, `helexa/large`)           |
/// | `:gpt-5`, `helexa:`   | (`None`, input unchanged)          |
///
/// Only the FIRST colon is a separator. Model ids often contain `/`
/// (HuggingFace style) but seldom `:`; for one that does, write the
/// endpoint explicitly (for example `default:llama3:8b`) so that the
/// remainder, colon included, is taken as the model id.
pub fn parse_model_selector(input: &str) -> (Option<&str>, &str) {
    if let Some(colon) = input.find(':') {
        let endpoint = &input[..colon];
        let model = &input[colon + 1..];
        // A selector is only recognised when both halves carry text.
        if !(endpoint.is_empty() || model.is_empty()) {
            return (Some(endpoint), model);
        }
    }
    (None, input)
}
|
||||
|
||||
/// Work out where the config file is expected to live.
///
/// Resolution order:
/// 1. `HELEXA_ACP_CONFIG_PATH`, used verbatim when set.
/// 2. `$XDG_CONFIG_HOME/helexa-acp/config.toml` when `XDG_CONFIG_HOME`
///    is set and non-empty.
/// 3. `$HOME/.config/helexa-acp/config.toml`.
///
/// Returns `None` when none of those variables yields a directory.
/// Whether the file actually exists is not checked here.
fn config_path() -> Option<PathBuf> {
    if let Ok(explicit) = std::env::var("HELEXA_ACP_CONFIG_PATH") {
        return Some(PathBuf::from(explicit));
    }

    let config_root = match std::env::var("XDG_CONFIG_HOME") {
        Ok(dir) if !dir.is_empty() => PathBuf::from(dir),
        _ => {
            let home = std::env::var("HOME").ok()?;
            PathBuf::from(home).join(".config")
        }
    };

    let mut file = config_root;
    file.push("helexa-acp");
    file.push("config.toml");
    Some(file)
}
|
||||
|
||||
/// Return a copy of `base` with `segments` appended to its path.
///
/// A trailing empty path segment (a base ending in `/`) is dropped
/// first, so the join does not produce `//`. When the URL offers no
/// mutable path segments — `path_segments_mut` returns `Err` — the
/// copy is returned unchanged.
fn join_segments(base: &Url, segments: &[&str]) -> Url {
    let mut joined = base.clone();
    if let Ok(mut path) = joined.path_segments_mut() {
        path.pop_if_empty();
        path.extend(segments.iter().copied());
    }
    joined
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint URLs are produced by appending path segments to the
    /// configured `/v1` base URL.
    #[test]
    fn url_join_handles_trailing_slash() {
        let endpoint = EndpointConfig {
            name: "x".into(),
            base_url: Url::parse("http://h.internal:31313/v1").unwrap(),
            wire_api: WireApi::OpenAiChat,
            default_model: None,
            api_key: None,
            api_key_env: None,
            max_tokens: None,
            context_window: None,
        };
        let chat_url = endpoint.chat_completions_url();
        let models_url = endpoint.models_url();
        assert_eq!(
            chat_url.as_str(),
            "http://h.internal:31313/v1/chat/completions"
        );
        assert_eq!(models_url.as_str(), "http://h.internal:31313/v1/models");
    }

    /// `endpoint:model` selectors split at the first colon; anything
    /// without a usable endpoint prefix is treated as a bare model id.
    #[test]
    fn parses_model_selector() {
        let cases = [
            ("helexa:helexa/large", (Some("helexa"), "helexa/large")),
            ("helexa/large", (None, "helexa/large")),
            ("gpt-5", (None, "gpt-5")),
            // Edge case: a leading colon → no endpoint.
            (":gpt-5", (None, ":gpt-5")),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_model_selector(input), expected);
        }
    }

    /// With none of the `HELEXA_ACP_*` variables present, the env
    /// fallback yields one endpoint built from the compiled-in defaults.
    #[test]
    fn env_fallback_builds_single_endpoint() {
        // Setting these variables would race with other tests, so the
        // test only clears them and checks the default path.
        //
        // SAFETY: NOTE(review) — nothing here serialises against other
        // test threads touching the environment; this relies on no other
        // test depending on these three variables.
        for var in [
            "HELEXA_ACP_BASE_URL",
            "HELEXA_ACP_MODEL",
            "HELEXA_ACP_API_KEY",
        ] {
            unsafe {
                std::env::remove_var(var);
            }
        }

        let cfg = Config::from_env().unwrap();
        assert_eq!(cfg.endpoints.len(), 1);

        let only = &cfg.endpoints[0];
        assert_eq!(only.name, "default");
        assert_eq!(only.base_url.as_str(), DEFAULT_BASE_URL);
        assert_eq!(only.default_model.as_deref(), Some(DEFAULT_MODEL));
    }

    /// A two-endpoint TOML document parses and validates; an endpoint
    /// that omits `wire_api` ends up as `WireApi::OpenAiChat`.
    #[test]
    fn toml_parses_multi_endpoint() {
        let document = r#"
            default_endpoint = "helexa"

            [[endpoints]]
            name = "helexa"
            base_url = "http://hanzalova.internal:31313/v1"
            default_model = "helexa/large"

            [[endpoints]]
            name = "openrouter"
            base_url = "https://openrouter.ai/api/v1"
            wire_api = "openai-chat"
            api_key_env = "OPENROUTER_API_KEY"
            default_model = "anthropic/claude-opus-4"
        "#;
        let mut cfg: Config = toml::from_str(document).unwrap();
        cfg.validate().unwrap();

        assert_eq!(cfg.endpoints.len(), 2);
        assert_eq!(cfg.default_endpoint().name, "helexa");
        assert_eq!(cfg.endpoints[0].wire_api, WireApi::OpenAiChat);
        assert_eq!(
            cfg.endpoints[1].api_key_env.as_deref(),
            Some("OPENROUTER_API_KEY")
        );
    }

    /// Endpoint names containing `:` are rejected by validation, with an
    /// error message that mentions the "clash".
    #[test]
    fn validate_rejects_colon_in_endpoint_name() {
        let document = r#"
            [[endpoints]]
            name = "bad:name"
            base_url = "http://x/v1"
        "#;
        let mut cfg: Config = toml::from_str(document).unwrap();
        let rejection = cfg.validate().unwrap_err();
        assert!(rejection.to_string().contains("clash"));
    }
}
|
||||
@@ -1,145 +0,0 @@
|
||||
//! helexa-acp — Agent Client Protocol bridge for multi-endpoint LLM
|
||||
//! setups (helexa, LM Studio, Ollama, OpenRouter, OpenAI, Anthropic,
|
||||
//! …) with a clean per-endpoint wire-format selector.
|
||||
//!
|
||||
//! Speaks ACP over stdio to an editor client (Zed today). Every
|
||||
//! configured endpoint produces a wire-format-specific
|
||||
//! [`provider::Provider`] implementation; the agent loop in
|
||||
//! [`agent::Agent`] is provider-agnostic, so adding e.g. an Anthropic
|
||||
//! /v1/messages provider doesn't touch `agent.rs`.
|
||||
//!
|
||||
//! Config: `$XDG_CONFIG_HOME/helexa-acp/config.toml` for the multi-
|
||||
//! endpoint case; env vars (`HELEXA_ACP_BASE_URL`, etc.) for the
|
||||
//! single-endpoint case when no config file exists.
|
||||
|
||||
use agent_client_protocol::{Result, Stdio};
|
||||
use std::sync::Arc;
|
||||
|
||||
mod agent;
|
||||
mod compaction;
|
||||
mod config;
|
||||
mod path_util;
|
||||
mod prompt;
|
||||
mod provider;
|
||||
mod qwen3;
|
||||
mod session;
|
||||
mod store;
|
||||
mod tool_runner;
|
||||
mod tools;
|
||||
|
||||
use agent::Agent;
|
||||
use config::{Config, EndpointConfig, WireApi};
|
||||
use provider::{
|
||||
Provider, anthropic_messages::AnthropicMessagesProvider, openai_chat::OpenAIChatProvider,
|
||||
openai_responses::OpenAIResponsesProvider,
|
||||
};
|
||||
|
||||
/// Set up tracing. Logs go to stderr by default — stdout is
|
||||
/// reserved for the JSON-RPC stream. Setting `HELEXA_ACP_LOG_FILE`
|
||||
/// to an absolute path appends logs to that file instead, which is
|
||||
/// the practical way to capture debug output when the agent runs
|
||||
/// under an editor (Zed, etc.) that doesn't surface stderr.
|
||||
///
|
||||
/// `RUST_LOG` still controls levels (e.g. `helexa_acp=debug`).
|
||||
/// ANSI colours are auto-stripped when writing to a file so the log
|
||||
/// is plain text.
|
||||
fn init_tracing() {
|
||||
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
|
||||
|
||||
let log_file = std::env::var("HELEXA_ACP_LOG_FILE")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
match log_file {
|
||||
Some(path) => match std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)
|
||||
{
|
||||
Ok(file) => {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::sync::Mutex::new(file))
|
||||
.with_env_filter(env_filter)
|
||||
.with_ansi(false)
|
||||
.init();
|
||||
}
|
||||
Err(e) => {
|
||||
// Fall back to stderr and shout. We don't want a
|
||||
// typo'd log path to silence the agent entirely.
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(env_filter)
|
||||
.init();
|
||||
tracing::warn!(
|
||||
path = %path,
|
||||
error = %e,
|
||||
"HELEXA_ACP_LOG_FILE could not be opened; using stderr"
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(env_filter)
|
||||
.init();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a provider for `endpoint` according to its declared
|
||||
/// `wire_api`. Future wire types (OpenAI Responses, Anthropic
|
||||
/// /v1/messages, Ollama native) slot in here without changing the
|
||||
/// caller.
|
||||
fn build_provider(endpoint: EndpointConfig) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
match endpoint.wire_api {
|
||||
WireApi::OpenAiChat => Ok(Arc::new(OpenAIChatProvider::new(endpoint)?)),
|
||||
WireApi::OpenAiResponses => Ok(Arc::new(OpenAIResponsesProvider::new(endpoint)?)),
|
||||
WireApi::AnthropicMessages => Ok(Arc::new(AnthropicMessagesProvider::new(endpoint)?)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
init_tracing();
|
||||
|
||||
let cfg = Config::load()
|
||||
.map_err(|e| agent_client_protocol::util::internal_error(format!("config: {e:#}")))?;
|
||||
tracing::info!(
|
||||
endpoints = cfg.endpoints.len(),
|
||||
default_endpoint = %cfg.default_endpoint().name,
|
||||
default_model = ?cfg.default_endpoint().default_model,
|
||||
"helexa-acp starting"
|
||||
);
|
||||
|
||||
// Build a provider for each configured endpoint up-front. Cheap —
|
||||
// just sets up a reqwest::Client and resolves the API key — and
|
||||
// surfaces config mistakes (missing API key env var, unsupported
|
||||
// wire_api) before the editor even sends an initialize request.
|
||||
let mut providers: Vec<Arc<dyn Provider>> = Vec::with_capacity(cfg.endpoints.len());
|
||||
for endpoint in &cfg.endpoints {
|
||||
match build_provider(endpoint.clone()) {
|
||||
Ok(p) => {
|
||||
tracing::info!(
|
||||
endpoint = %endpoint.name,
|
||||
base_url = %endpoint.base_url,
|
||||
wire_api = ?endpoint.wire_api,
|
||||
"registered provider"
|
||||
);
|
||||
providers.push(p);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
endpoint = %endpoint.name,
|
||||
error = %format!("{e:#}"),
|
||||
"skipping endpoint with invalid config"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let agent = Agent::new(&cfg, providers)
|
||||
.await
|
||||
.map_err(|e| agent_client_protocol::util::internal_error(format!("agent: {e:#}")))?;
|
||||
agent.serve(Stdio::new()).await
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
//! Path expansion shared across every tool that takes a path.
|
||||
//!
|
||||
//! Models often emit shell-style paths like `~/git/repo/file.rs` or
|
||||
//! `$HOME/notes.md`. ACP's `fs/read_text_file` and friends — and our
|
||||
//! own local `std::fs` reads — both want a real absolute path; the
|
||||
//! `~` / `$HOME` forms reach them as literal strings and the open
|
||||
//! fails. The tool schemas already document "absolute path" but in
|
||||
//! practice the model slips up often enough that handling it
|
||||
//! server-side is the difference between "works" and "the agent is
|
||||
//! brittle".
|
||||
//!
|
||||
//! Scope is deliberately small:
|
||||
//!
|
||||
//! - `~` and `~/` (current user only — `~user` lookups would require
|
||||
//! pulling in passwd parsing).
|
||||
//! - `$HOME` and `$HOME/`.
|
||||
//!
|
||||
//! Any other shell variable (`$PWD`, `${HOME}`, …) passes through
|
||||
//! unchanged. The shell already expands them inside `bash` tool
|
||||
//! commands; for the file-tool argument fields, we deliberately
|
||||
//! limit the set so the behaviour is predictable.
|
||||
//!
|
||||
//! Falls back to the input path verbatim when `HOME` is unset
|
||||
//! (stripped-down container env). That preserves the "no surprise
|
||||
//! mutations" rule — never invent a path the caller didn't ask for.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Process-global lock for tests that mutate `HOME`. Anyone in the
|
||||
/// crate touching `HOME` must hold this for the duration of the
|
||||
/// read-modify-restore window — otherwise concurrent `cargo test`
|
||||
/// workers race and flake.
|
||||
///
|
||||
/// Only built into the test binaries. Production code never mutates
|
||||
/// env vars.
|
||||
#[cfg(test)]
|
||||
pub(crate) static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||
|
||||
/// Expand `~`, `~/`, `$HOME`, and `$HOME/` prefixes against the
/// current user's home directory. All other inputs pass through
/// unchanged.
///
/// The input is returned verbatim when:
///
/// - it is not valid UTF-8 (the prefixes cannot be recognised),
/// - it does not start with one of the four supported forms, or
/// - `HOME` is unset or empty. An empty `HOME` would otherwise turn
///   `~/x` into the relative path `x`, silently pointing at a
///   different file.
///
/// `HOME` is read as an OS string, so a home directory whose name is
/// not valid UTF-8 still expands. Extra slashes directly after the
/// prefix (`~//x`) are ignored, so the result always stays underneath
/// the home directory instead of being replaced by an absolute
/// remainder.
pub fn expand_path(input: &Path) -> PathBuf {
    let Some(text) = input.to_str() else {
        return input.to_path_buf();
    };

    // `None` = the input is exactly the home directory;
    // `Some(rest)` = the input is `rest` underneath it.
    let remainder = if text == "~" || text == "$HOME" {
        None
    } else if let Some(rest) = text
        .strip_prefix("~/")
        .or_else(|| text.strip_prefix("$HOME/"))
    {
        // `Path::join` with an absolute argument discards the base, so
        // strip any further leading slashes before joining.
        Some(rest.trim_start_matches('/'))
    } else {
        return input.to_path_buf();
    };

    let home = match std::env::var_os("HOME") {
        Some(home) if !home.is_empty() => PathBuf::from(home),
        _ => return input.to_path_buf(),
    };

    match remainder {
        None => home,
        Some(rest) => home.join(rest),
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// RAII guard for tests that need a specific `HOME`.
    ///
    /// Holds the crate-wide [`ENV_LOCK`] for its whole lifetime and puts
    /// the previous `HOME` value back on drop. Because the restore runs
    /// in `Drop`, it also happens when an assertion panics mid-test, so
    /// one failing test cannot leak a fake `HOME` into the tests that
    /// run after it.
    struct HomeGuard {
        prior: Option<std::ffi::OsString>,
        // Declared last so the lock is released only after `Drop::drop`
        // has restored `HOME`.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl HomeGuard {
        /// Take the lock, remember the current `HOME`, then set it to
        /// `home` (`Some`) or remove it (`None`).
        fn install(home: Option<&str>) -> Self {
            // A poisoned lock only means an earlier test panicked while
            // holding it. That test's guard already restored `HOME`, so
            // the protected state is fine and the poison can be ignored
            // instead of failing every later test.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let prior = std::env::var_os("HOME");
            // SAFETY: tests touch process-global env. The mutex
            // serialises every test that goes through this guard; other
            // test modules mutating HOME are expected to hold ENV_LOCK
            // as well.
            unsafe {
                match home {
                    Some(value) => std::env::set_var("HOME", value),
                    None => std::env::remove_var("HOME"),
                }
            }
            HomeGuard { prior, _lock: lock }
        }
    }

    impl Drop for HomeGuard {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held here (fields are dropped
            // after this body runs), so access stays serialised.
            unsafe {
                match self.prior.take() {
                    Some(value) => std::env::set_var("HOME", value),
                    None => std::env::remove_var("HOME"),
                }
            }
        }
    }

    /// Run `body` with `HOME` set to `home`. Tests using this run
    /// serially under the crate-wide [`ENV_LOCK`] because env mutation
    /// isn't thread-safe — `cargo test` parallel workers would race
    /// without it. `HOME` is restored even if `body` panics.
    fn with_home<F: FnOnce()>(home: &str, body: F) {
        let _guard = HomeGuard::install(Some(home));
        body();
    }

    #[test]
    fn expands_tilde_slash() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("~/git/repo/file.rs")),
                PathBuf::from("/home/me/git/repo/file.rs")
            );
        });
    }

    #[test]
    fn expands_bare_tilde() {
        with_home("/home/me", || {
            assert_eq!(expand_path(Path::new("~")), PathBuf::from("/home/me"));
        });
    }

    #[test]
    fn expands_dollar_home_slash() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("$HOME/notes.md")),
                PathBuf::from("/home/me/notes.md")
            );
        });
    }

    #[test]
    fn expands_bare_dollar_home() {
        with_home("/home/me", || {
            assert_eq!(expand_path(Path::new("$HOME")), PathBuf::from("/home/me"));
        });
    }

    #[test]
    fn absolute_path_passes_through() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("/etc/hostname")),
                PathBuf::from("/etc/hostname")
            );
        });
    }

    #[test]
    fn relative_path_passes_through() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("src/main.rs")),
                PathBuf::from("src/main.rs")
            );
        });
    }

    #[test]
    fn tilde_user_form_not_expanded() {
        // ~other is shell sugar for /home/other and would require
        // passwd parsing to resolve. Out of scope — pass it
        // through and let the open fail with a clear error.
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("~other/x")),
                PathBuf::from("~other/x")
            );
        });
    }

    #[test]
    fn no_home_env_passes_through() {
        // Same guard as `with_home`, but clearing HOME instead of
        // setting it; the prior value comes back when the guard drops.
        let _guard = HomeGuard::install(None);
        assert_eq!(
            expand_path(Path::new("~/git/repo")),
            PathBuf::from("~/git/repo")
        );
    }

    #[test]
    fn dollar_other_var_not_expanded() {
        with_home("/home/me", || {
            assert_eq!(
                expand_path(Path::new("$PWD/file")),
                PathBuf::from("$PWD/file")
            );
            assert_eq!(
                expand_path(Path::new("${HOME}/file")),
                PathBuf::from("${HOME}/file")
            );
        });
    }
}
|
||||
@@ -1,274 +0,0 @@
|
||||
//! System prompt assembly.
|
||||
//!
|
||||
//! The system message has two parts:
|
||||
//!
|
||||
//! 1. A short human-readable preamble (working directory, style
|
||||
//! instructions). Either the built-in [`DEFAULT_PROMPT`] or a
|
||||
//! user-supplied file at `HELEXA_ACP_SYSTEM_PROMPT_PATH` /
|
||||
//! `system_prompt_path`. `{cwd}` is substituted in both.
|
||||
//! 2. A `# Tools` block in Qwen3 Hermes format (see [`crate::qwen3`])
|
||||
//! describing the available functions. This is what makes the
|
||||
//! model actually call them — neuron/cortex don't honour the
|
||||
//! OpenAI `tools` API field, so the tool list has to live in the
|
||||
//! prompt itself.
|
||||
|
||||
use agent_client_protocol::schema::SessionModeId;
|
||||
use anyhow::Context;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::provider::ToolSpec;
|
||||
use crate::qwen3;
|
||||
use crate::session::MODE_PLAN;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "\
|
||||
You are helexa-acp, a coding assistant working inside an editor.
|
||||
|
||||
Working directory: {cwd}
|
||||
|
||||
Use the tools described below whenever the user's request involves
|
||||
looking at or modifying files, or running commands. Do not ask the
|
||||
user to paste file contents you could read yourself. All file paths
|
||||
must be absolute. Writes and shell commands may prompt the user for
|
||||
permission depending on the session mode.
|
||||
|
||||
Be concise; the user is reading your output in an editor pane.";
|
||||
|
||||
/// Build the system prompt for a session.
|
||||
///
|
||||
/// - `cwd`: session working directory (substituted for `{cwd}` in
|
||||
/// the preamble — both the default and any user-supplied template).
|
||||
/// - `override_path`: path to a user-supplied template, already
|
||||
/// resolved by [`crate::config::Config`]. The `# Tools` block is
|
||||
/// appended *after* the user's template so a custom preamble
|
||||
/// still gets the tool descriptions the model needs.
|
||||
/// - `tools`: the tools to advertise. Empty list → no `# Tools`
|
||||
/// block is appended at all.
|
||||
/// - `mode`: current session mode. When the mode is [`MODE_PLAN`]
|
||||
/// a plan-mode addendum describing the restrictions and the
|
||||
/// completion menu is appended *after* the `# Tools` block so it
|
||||
/// is the last thing the model reads before user input.
|
||||
/// - `plan_dir`: resolved plan directory for the cwd. Only consulted
|
||||
/// when `mode == MODE_PLAN`. `None` means the plan directory could
|
||||
/// not be resolved (no `HOME` / `XDG_DATA_HOME`) — the addendum
|
||||
/// still renders but with a placeholder so the model knows to
|
||||
/// surface the error to the user rather than guess a path.
|
||||
pub fn build_system_prompt(
|
||||
cwd: &Path,
|
||||
override_path: Option<&Path>,
|
||||
tools: &[ToolSpec],
|
||||
mode: &SessionModeId,
|
||||
plan_dir: Option<&Path>,
|
||||
) -> anyhow::Result<String> {
|
||||
let template = match override_path {
|
||||
Some(path) => std::fs::read_to_string(path)
|
||||
.with_context(|| format!("read system prompt from {}", path.display()))?,
|
||||
None => DEFAULT_PROMPT.to_string(),
|
||||
};
|
||||
let mut prompt = template.replace("{cwd}", &cwd.display().to_string());
|
||||
prompt.push_str(&qwen3::render_tool_block(tools));
|
||||
if mode.0.as_ref() == MODE_PLAN {
|
||||
prompt.push_str(&render_plan_mode_block(plan_dir));
|
||||
}
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Render the `# Plan mode` addendum for the system prompt.
///
/// The block tells the model:
///
/// 1. Where it may write — only inside `plan_dir`.
/// 2. What it may *not* do — bash is disabled, and writes outside
///    `plan_dir` are refused by the runtime.
/// 3. How to finish — by emitting the 3-option menu, so the user can
///    switch modes and either start implementation (with or without
///    permission prompts) or keep iterating on the plan.
///
/// With `plan_dir == None`, a placeholder asking the model to tell the
/// user is substituted wherever the directory would appear.
fn render_plan_mode_block(plan_dir: Option<&Path>) -> String {
    let location = match plan_dir {
        Some(dir) => dir.display().to_string(),
        None => "<plan directory could not be resolved — tell the user>".to_string(),
    };
    format!(
        "\n\n# Plan mode\n\
         \n\
         You are in **plan mode**. Your task is to draft a written\n\
         implementation plan for the user; you must NOT modify any\n\
         project files or run shell commands.\n\
         \n\
         Rules in plan mode:\n\
         \n\
         - `read_file` and `list_dir` are unrestricted — use them to\n\
         explore the codebase as needed.\n\
         - `write_file` and `edit_file` are allowed ONLY under the\n\
         plan directory: `{plan_path}`. The runtime will refuse any\n\
         write outside it.\n\
         - `bash` is disabled. Do not call it.\n\
         \n\
         Write the plan as one or more Markdown files under\n\
         `{plan_path}`. Use descriptive filenames\n\
         (`01-overview.md`, `02-data-model.md`, etc.). It is fine to\n\
         iterate — overwrite the file when you refine a section.\n\
         \n\
         When the plan is complete, do NOT begin implementation.\n\
         Instead, end your turn with this menu, verbatim, so the\n\
         user can choose how to proceed:\n\
         \n\
         ---\n\
         **Plan complete.** To proceed, switch the session mode in\n\
         the agent dropdown and send a follow-up message:\n\
         \n\
         1. **Bypass Permissions** — implement the plan now, skipping\n\
         per-tool permission prompts.\n\
         2. **Default** — implement the plan now, prompting before\n\
         each write or shell command.\n\
         3. **Plan** (stay here) — refine the plan; reply with the\n\
         change you want and I will revise it.\n\
         ---\n",
        plan_path = location
    )
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::{MODE_DEFAULT, MODE_PLAN};
    use std::io::Write;

    /// Session mode id for the default mode.
    fn default_mode() -> SessionModeId {
        SessionModeId::new(MODE_DEFAULT)
    }

    /// Session mode id for plan mode.
    fn plan_mode() -> SessionModeId {
        SessionModeId::new(MODE_PLAN)
    }

    #[test]
    fn default_prompt_substitutes_cwd() {
        let cwd = Path::new("/home/me/proj");
        let prompt = build_system_prompt(cwd, None, &[], &default_mode(), None).unwrap();

        assert!(
            prompt.contains("/home/me/proj"),
            "cwd not interpolated: {prompt}"
        );
        assert!(prompt.contains("helexa-acp"));
        assert!(
            !prompt.contains("{cwd}"),
            "left-over placeholder in default prompt"
        );
        // With no tools, the # Tools block is absent.
        assert!(!prompt.contains("# Tools"));
        // Default mode does not get the plan-mode addendum.
        assert!(!prompt.contains("# Plan mode"));
    }

    #[test]
    fn tools_are_appended_in_hermes_format() {
        let read_file = ToolSpec {
            name: "read_file".into(),
            description: "Read a file.".into(),
            parameters: serde_json::json!({"type":"object","properties":{}, "required":[]}),
        };
        let prompt =
            build_system_prompt(Path::new("/x"), None, &[read_file], &default_mode(), None)
                .unwrap();

        for needle in ["# Tools", "<tools>", "\"name\":\"read_file\"", "<tool_call>"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn override_path_is_read_and_templated() {
        let mut handle = tempfile_in_target("prompt.txt");
        handle.write_all(b"custom prompt for {cwd} only").unwrap();
        handle.flush().unwrap();

        // Keep the path, close the file, then read it back through the
        // function under test.
        let template_path = handle.path().to_path_buf();
        drop(handle);

        let prompt = build_system_prompt(
            Path::new("/etc"),
            Some(template_path.as_path()),
            &[],
            &default_mode(),
            None,
        )
        .expect("read override");
        assert_eq!(prompt, "custom prompt for /etc only");

        let _ = std::fs::remove_file(&template_path);
    }

    #[test]
    fn missing_override_path_errors() {
        let missing = Path::new("/definitely/not/a/real/path");
        let failure = build_system_prompt(
            Path::new("/tmp"),
            Some(missing),
            &[],
            &default_mode(),
            None,
        )
        .unwrap_err();
        assert!(format!("{failure:#}").contains("read system prompt"));
    }

    #[test]
    fn plan_mode_addendum_includes_plan_dir_and_menu() {
        let plan_dir = Path::new("/home/me/.local/share/helexa-acp/plans/proj-deadbeef");
        let prompt = build_system_prompt(
            Path::new("/home/me/proj"),
            None,
            &[],
            &plan_mode(),
            Some(plan_dir),
        )
        .unwrap();

        assert!(prompt.contains("# Plan mode"));
        assert!(
            prompt.contains(plan_dir.to_str().unwrap()),
            "plan dir not interpolated: {prompt}"
        );
        // The 3-option menu must be present so the model emits it verbatim.
        assert!(prompt.contains("Bypass Permissions"));
        assert!(prompt.contains("**Default**"));
        assert!(prompt.contains("3. **Plan**"));
        // Bash disabled instruction must be present.
        assert!(prompt.contains("`bash` is disabled"));
    }

    #[test]
    fn plan_mode_addendum_handles_unresolved_plan_dir() {
        let prompt =
            build_system_prompt(Path::new("/home/me/proj"), None, &[], &plan_mode(), None)
                .unwrap();
        assert!(prompt.contains("# Plan mode"));
        assert!(prompt.contains("could not be resolved"));
    }

    /// Tiny temp-file helper that doesn't pull in the `tempfile` crate.
    ///
    /// The file is created in the directory named by the
    /// `CARGO_TARGET_TMPDIR` environment variable when that is present
    /// at run time, and in the system temp directory otherwise. The
    /// file name includes the process id. Nothing is deleted
    /// automatically; the caller removes the file.
    fn tempfile_in_target(name: &str) -> TempHandle {
        let dir = match std::env::var("CARGO_TARGET_TMPDIR") {
            Ok(target_tmp) => std::path::PathBuf::from(target_tmp),
            Err(_) => std::env::temp_dir(),
        };
        let _ = std::fs::create_dir_all(&dir);

        let pid = std::process::id();
        let path = dir.join(format!("helexa-acp-{pid}-{name}"));
        let file = std::fs::File::create(&path).expect("create temp file");
        TempHandle { file, path }
    }

    /// An open file together with the path it was created at.
    struct TempHandle {
        file: std::fs::File,
        path: std::path::PathBuf,
    }

    impl TempHandle {
        /// Location of the file on disk.
        fn path(&self) -> &Path {
            self.path.as_path()
        }
    }

    /// Writes are forwarded to the underlying file.
    impl Write for TempHandle {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.file.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.file.flush()
        }
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,230 +0,0 @@
|
||||
//! Provider trait — the seam between the ACP-side agent loop and
|
||||
//! whatever wire protocol an endpoint actually speaks.
|
||||
//!
|
||||
//! Every concrete provider (OpenAI chat completions, OpenAI Responses,
|
||||
//! Anthropic /v1/messages, Ollama native, …) implements
|
||||
//! [`Provider`]. The agent constructs a [`CompletionRequest`] using
|
||||
//! provider-agnostic types and consumes a stream of
|
||||
//! [`CompletionEvent`]s — neither end knows which wire format is on
|
||||
//! the other side of the trait.
|
||||
//!
|
||||
//! Day-1 provider: [`openai_chat::OpenAIChatProvider`]. Day-N
|
||||
//! providers slot in without touching `agent.rs`.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub mod anthropic_messages;
|
||||
pub mod openai_chat;
|
||||
pub mod openai_responses;
|
||||
|
||||
/// Provider-agnostic LLM endpoint. Implementations translate between
|
||||
/// [`CompletionRequest`] / [`CompletionEvent`] and whatever wire
|
||||
/// format their endpoint speaks.
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Endpoint name as configured by the user (e.g. `"helexa"`,
|
||||
/// `"openrouter"`). Used in logs and in the `endpoint:model`
|
||||
/// selector.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// List models available at this endpoint. Used to build the
|
||||
/// model-picker dropdown in editor clients (Stage 4). Should
|
||||
/// return quickly (cache if necessary).
|
||||
#[allow(dead_code)]
|
||||
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>>;
|
||||
|
||||
/// Run a chat completion. Returns a stream of provider-agnostic
|
||||
/// events. The stream stops when the upstream finishes, when
|
||||
/// `cancel` is fired, or when the stream is dropped.
|
||||
async fn complete(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>>;
|
||||
}
|
||||
|
||||
/// One model exposed by a provider. Constructed by `list_models` —
|
||||
/// Stage 4 is when the agent loop starts consuming it for the
|
||||
/// model-picker dropdown.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
/// Human-friendly name, if the endpoint exposes one. Otherwise
|
||||
/// `id` is used as the display name.
|
||||
#[serde(default)]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Inputs to a completion. Provider-agnostic — concrete providers
|
||||
/// translate this into their wire format.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompletionRequest {
|
||||
/// Endpoint-local model id (without the `endpoint:` prefix).
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
/// Tools the model is allowed to call. Empty list means no tool
|
||||
/// support advertised.
|
||||
pub tools: Vec<ToolSpec>,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
pub max_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
/// Tool result message. Provider impls turn this into whatever
|
||||
/// shape the upstream wire format wants (OpenAI uses
|
||||
/// `role: "tool"` + `tool_call_id`; Anthropic uses content blocks).
|
||||
/// Stage 3 (tools) constructs this; Stage 2 never does.
|
||||
Tool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessageContent {
|
||||
/// Plain text turn (system / user / assistant). Struct variant
|
||||
/// rather than newtype so the persisted JSON has an explicit
|
||||
/// `text` field — that lets us use internal tagging on the
|
||||
/// enum, which is incompatible with newtype-of-primitive
|
||||
/// variants.
|
||||
Text { text: String },
|
||||
/// Mixed text + image user turn. Stage 5 introduces this when
|
||||
/// Zed sends an `ImageContent` block alongside the user's prompt.
|
||||
/// Providers that don't support vision should down-convert by
|
||||
/// dropping image parts and concatenating text parts.
|
||||
MultiPart { parts: Vec<MessagePart> },
|
||||
/// Assistant turn that called one or more tools. Stage 3 starts
|
||||
/// constructing this when the provider stream yields a
|
||||
/// `ToolCallStart` / `ToolCallArgsDelta` sequence.
|
||||
ToolCalls {
|
||||
/// Optional text the assistant said alongside the tool calls.
|
||||
text: Option<String>,
|
||||
calls: Vec<ToolCall>,
|
||||
},
|
||||
/// Tool result. `tool_call_id` matches the assistant's call id.
|
||||
/// Stage 3 constructs this after the tool runner finishes.
|
||||
ToolResult {
|
||||
tool_call_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// One part of a [`MessageContent::MultiPart`] message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
Image(ImageData),
|
||||
}
|
||||
|
||||
/// Inline image attachment. `data` is base64-encoded raw image
|
||||
/// bytes; the encoder constructs an `image_url` data URI from it
|
||||
/// at request time. `uri` carries any pointer the client supplied
|
||||
/// (e.g. `file:///tmp/x.png`) — we keep it on the message for
|
||||
/// debugging / future providers but the OpenAI encoder ignores it
|
||||
/// when `data` is present (data wins, since it round-trips through
|
||||
/// every wire format).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageData {
|
||||
pub mime_type: String,
|
||||
/// Base64-encoded image bytes (no `data:` prefix, no padding
|
||||
/// stripped — exactly what `ImageContent.data` carried).
|
||||
pub data: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
/// Provider-assigned id that ties the call to its result. The
|
||||
/// Qwen3 wire format we use today doesn't carry this on the
|
||||
/// model side (calls and results are matched positionally inside
|
||||
/// a turn), so the field looks unused in the prod build — but it
|
||||
/// flows through to `MessageContent::ToolResult.tool_call_id` for
|
||||
/// history bookkeeping and a future strict-OpenAI backend will
|
||||
/// consume it directly.
|
||||
#[allow(dead_code)]
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
/// JSON-encoded arguments. Kept as a string because providers
|
||||
/// stream argument bytes incrementally and only validate at the
|
||||
/// end; the agent decodes once the call is complete.
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolSpec {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
/// JSON Schema of the arguments object.
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
/// Events emitted by a provider during a streaming completion.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CompletionEvent {
|
||||
/// Incremental visible text from the assistant.
|
||||
TextDelta(String),
|
||||
/// Incremental "reasoning" / thought text, if the model emits one
|
||||
/// (e.g. Qwen3 with `<think>` tags surfaced as a separate stream,
|
||||
/// or OpenAI reasoning models).
|
||||
ReasoningDelta(String),
|
||||
/// A new tool call has started. Stage 2 ignores the payload; the
|
||||
/// agent loop in Stage 3 reads `index` to correlate with
|
||||
/// [`Self::ToolCallArgsDelta`], `id` for the eventual tool-result
|
||||
/// turn, and `name` to dispatch the runner.
|
||||
#[allow(dead_code)]
|
||||
ToolCallStart {
|
||||
index: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
},
|
||||
/// More argument bytes for a tool call already announced via
|
||||
/// [`Self::ToolCallStart`]. Stage 2 ignores; Stage 3 accumulates
|
||||
/// the bytes by `index` until the call's arguments are complete.
|
||||
#[allow(dead_code)]
|
||||
ToolCallArgsDelta { index: usize, args_delta: String },
|
||||
/// A `<tool_call>` block whose JSON couldn't be parsed even with
|
||||
/// the qwen3 module's repair attempts. The agent surfaces this
|
||||
/// as a Failed `SessionUpdate::ToolCall` card with the raw body
|
||||
/// visible (so the editor renders structured failure UI rather
|
||||
/// than dumping the body inline in the message pane), and feeds
|
||||
/// a synthetic tool-error message back into history so the
|
||||
/// model can self-correct on the next round.
|
||||
MalformedToolCall { raw: String },
|
||||
/// Stream finished. Carries the upstream `finish_reason` if it
|
||||
/// gave one (`"stop"`, `"length"`, `"tool_calls"`, …).
|
||||
Finish { reason: Option<String> },
|
||||
/// Final usage stats, if the provider supplied them. Stage 2
|
||||
/// matches the variant to drop it; Stage 6b (token metrics) is
|
||||
/// when the payload starts being read.
|
||||
#[allow(dead_code)]
|
||||
Usage(UsageStats),
|
||||
}
|
||||
|
||||
/// Token counts a provider reports once a stream has finished.
///
/// Stage 2 does not surface usage anywhere — the stable `PromptResponse`
/// has no usage field and the unstable variant is gated. Stage 6b switches
/// these on together with Prometheus metrics.
#[derive(Debug, Clone, Copy, Default)]
#[allow(dead_code)]
pub struct UsageStats {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u64,
    /// Tokens counted for the completion.
    pub completion_tokens: u64,
    /// Total token count as reported by the provider.
    pub total_tokens: u64,
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,987 +0,0 @@
|
||||
//! OpenAI Responses API (`POST /v1/responses`) provider.
|
||||
//!
|
||||
//! Mirror image of [`super::openai_chat`]: same `Provider` trait
|
||||
//! impl, same back-pressured SSE decoder, but speaking OpenAI's
|
||||
//! newer Responses surface instead of chat completions.
|
||||
//!
|
||||
//! Differences from the chat provider, all contained in this file:
|
||||
//!
|
||||
//! - **Request encoding**: history flattens into an `input` array
|
||||
//! of typed items (`message`, `function_call`, `function_call_output`)
|
||||
//! plus a top-level `instructions` field for the system prompt.
|
||||
//! Multi-part user content stays in the same `[{type:"input_text"},
|
||||
//! {type:"input_image"}]` shape neuron's `request_to_chat` already
|
||||
//! accepts.
|
||||
//! - **Streaming decoder**: events are named (`response.created`,
|
||||
//! `response.output_text.delta`, `response.completed`, …) carried
|
||||
//! on the SSE `event:` line. The chat path's `[DONE]` terminator
|
||||
//! doesn't apply; the stream ends after `response.completed`.
|
||||
//! - **Tool calls** plumb through the `response.output_item.added`
|
||||
//! (item type `function_call`) → `response.function_call_arguments.delta`
|
||||
//! → `response.function_call_arguments.done` event sequence. The
|
||||
//! neuron candle harness doesn't synthesize these yet (tracked as
|
||||
//! issue #6), but the decoder is wired so the day the upstream
|
||||
//! does, downstream `CompletionEvent::ToolCall*` plumbing just
|
||||
//! works.
|
||||
//!
|
||||
//! Tool-name handling: the model knows its tool descriptions via
|
||||
//! the [`crate::qwen3`] system-prompt block exactly the way the chat
|
||||
//! provider does. We don't echo them in the request body because
|
||||
//! neuron currently ignores `tools` on /v1/responses (same as on
|
||||
//! /v1/chat/completions). Once neuron honours request-side tool
|
||||
//! definitions, both providers add them in the same place.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::{Stream, StreamExt, stream::BoxStream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::{
|
||||
CompletionEvent, CompletionRequest, Message, MessageContent, MessagePart, ModelInfo, Provider,
|
||||
Role, UsageStats,
|
||||
};
|
||||
use crate::config::EndpointConfig;
|
||||
|
||||
pub struct OpenAIResponsesProvider {
|
||||
endpoint: EndpointConfig,
|
||||
#[allow(dead_code)] // Read in `complete()`'s HTTP path; tests don't stand up a server.
|
||||
api_key: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl OpenAIResponsesProvider {
|
||||
pub fn new(endpoint: EndpointConfig) -> anyhow::Result<Self> {
|
||||
let api_key = endpoint.resolve_api_key()?;
|
||||
let http = reqwest::Client::builder()
|
||||
// Same generous timeout as the chat provider: cortex may
|
||||
// need to cold-load a model before serving the first
|
||||
// chunk, which can be tens of seconds. Cancellation
|
||||
// handles early termination, not timeout.
|
||||
.timeout(std::time::Duration::from_secs(600))
|
||||
.build()?;
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
api_key,
|
||||
http,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OpenAIResponsesProvider {
|
||||
fn name(&self) -> &str {
|
||||
&self.endpoint.name
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
|
||||
let mut req = self.http.get(self.endpoint.models_url());
|
||||
if let Some(key) = &self.api_key {
|
||||
req = req.bearer_auth(key);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{} list_models: {e}", self.endpoint.name))?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"{} list_models returned {}: {}",
|
||||
self.endpoint.name,
|
||||
status,
|
||||
body
|
||||
);
|
||||
}
|
||||
let body: WireModelsResponse = resp.json().await?;
|
||||
Ok(body
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|m| ModelInfo {
|
||||
id: m.id,
|
||||
display_name: None,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<BoxStream<'static, anyhow::Result<CompletionEvent>>> {
|
||||
let body = encode_request(&request);
|
||||
tracing::debug!(
|
||||
endpoint = %self.endpoint.name,
|
||||
url = %self.endpoint.responses_url(),
|
||||
body = %serde_json::to_string(&body).unwrap_or_else(|_| "<unserializable>".into()),
|
||||
"POST /responses"
|
||||
);
|
||||
let mut req = self.http.post(self.endpoint.responses_url()).json(&body);
|
||||
if let Some(key) = &self.api_key {
|
||||
req = req.bearer_auth(key);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{} responses send: {e}", self.endpoint.name))?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"{} responses returned {}: {}",
|
||||
self.endpoint.name,
|
||||
status,
|
||||
body
|
||||
);
|
||||
}
|
||||
let sse = resp.bytes_stream().eventsource();
|
||||
let stream = decode_stream(sse, cancel);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Request encoding ─────────────────────────────────────────────────
|
||||
|
||||
fn encode_request(req: &CompletionRequest) -> Value {
|
||||
// Pull the system messages out of history into a single
|
||||
// `instructions` string — the Responses API expects them there,
|
||||
// not inline as an `input` item. Multiple system messages
|
||||
// concatenate with blank lines so we don't lose ordering.
|
||||
let mut instructions: Vec<String> = Vec::new();
|
||||
let mut input_items: Vec<Value> = Vec::new();
|
||||
for msg in &req.messages {
|
||||
if msg.role == Role::System
|
||||
&& let MessageContent::Text { text } = &msg.content
|
||||
{
|
||||
instructions.push(text.clone());
|
||||
continue;
|
||||
}
|
||||
if let Some(item) = encode_message_as_input_item(msg) {
|
||||
input_items.push(item);
|
||||
}
|
||||
}
|
||||
|
||||
let mut body = json!({
|
||||
"model": req.model,
|
||||
"input": input_items,
|
||||
"stream": true,
|
||||
});
|
||||
if let Value::Object(map) = &mut body {
|
||||
if !instructions.is_empty() {
|
||||
map.insert(
|
||||
"instructions".into(),
|
||||
Value::String(instructions.join("\n\n")),
|
||||
);
|
||||
}
|
||||
if let Some(t) = req.temperature {
|
||||
map.insert("temperature".into(), json!(t));
|
||||
}
|
||||
if let Some(p) = req.top_p {
|
||||
map.insert("top_p".into(), json!(p));
|
||||
}
|
||||
if let Some(m) = req.max_tokens {
|
||||
// Responses calls it `max_output_tokens`; preserve the
|
||||
// semantic (response cap) when we translate.
|
||||
map.insert("max_output_tokens".into(), json!(m));
|
||||
}
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
fn encode_message_as_input_item(msg: &Message) -> Option<Value> {
|
||||
match (msg.role, &msg.content) {
|
||||
(Role::System, _) => None, // handled out-of-band as `instructions`
|
||||
(Role::User, MessageContent::Text { text }) => Some(json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": text,
|
||||
})),
|
||||
(Role::User, MessageContent::MultiPart { parts }) => Some(json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": encode_user_parts(parts),
|
||||
})),
|
||||
(Role::Assistant, MessageContent::Text { text }) => Some(json!({
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
"annotations": [],
|
||||
}],
|
||||
})),
|
||||
(Role::Assistant, MessageContent::ToolCalls { text, calls }) => {
|
||||
// Assistant turns that called tools become a sequence of
|
||||
// items: an optional `message` (any prose alongside the
|
||||
// call) followed by one `function_call` per call. Mirrors
|
||||
// OpenAI Responses' "each item is one structural slot"
|
||||
// shape.
|
||||
//
|
||||
// We can't return multiple items from one call site, so
|
||||
// we encode this by side-stuffing additional items into a
|
||||
// single composite value and have the caller flatten —
|
||||
// but that complicates the API. Easier: build the array
|
||||
// ourselves in the caller path. For now, emit just the
|
||||
// function_calls (the assistant's prose lives in the next
|
||||
// turn's chat history anyway because the model isn't
|
||||
// looking back at its own previous narration). If the
|
||||
// text is non-empty AND we have calls, we lose the text;
|
||||
// qwen3 rarely emits prose alongside tool calls so this
|
||||
// is a deliberate simplification — revisit if it bites.
|
||||
let _ = text;
|
||||
// Take the first call only for the moment; multi-call
|
||||
// turns would need the caller-flattening above.
|
||||
let call = calls.first()?;
|
||||
Some(json!({
|
||||
"type": "function_call",
|
||||
"call_id": call.id,
|
||||
"name": call.name,
|
||||
"arguments": call.arguments,
|
||||
}))
|
||||
}
|
||||
(
|
||||
Role::Tool,
|
||||
MessageContent::ToolResult {
|
||||
tool_call_id,
|
||||
content,
|
||||
},
|
||||
) => Some(json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call_id,
|
||||
"output": content,
|
||||
})),
|
||||
(role, content) => {
|
||||
tracing::warn!(
|
||||
?role,
|
||||
?content,
|
||||
"openai_responses: unexpected (role, content) shape"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_user_parts(parts: &[MessagePart]) -> Value {
|
||||
let items: Vec<Value> = parts
|
||||
.iter()
|
||||
.map(|p| match p {
|
||||
MessagePart::Text { text } => json!({"type": "input_text", "text": text}),
|
||||
MessagePart::Image(img) => json!({
|
||||
"type": "input_image",
|
||||
"image_url": format!("data:{};base64,{}", img.mime_type, img.data),
|
||||
}),
|
||||
})
|
||||
.collect();
|
||||
Value::Array(items)
|
||||
}
|
||||
|
||||
// ── Wire types ──────────────────────────────────────────────────────
|
||||
|
||||
#[allow(dead_code)] // fields read only when list_models runs against a real endpoint
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WireModelsResponse {
|
||||
data: Vec<WireModelObject>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WireModelObject {
|
||||
id: String,
|
||||
}
|
||||
|
||||
// SSE event payload shapes. We only model the fields we care about;
|
||||
// `#[serde(default)]` + `Option` everywhere else lets the upstream
|
||||
// add optional fields without breaking deserialise.
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OutputItemAddedEvent {
|
||||
#[serde(default)]
|
||||
output_index: u32,
|
||||
item: OutputItem,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum OutputItem {
|
||||
Message {
|
||||
#[serde(default)]
|
||||
id: Option<String>,
|
||||
},
|
||||
FunctionCall {
|
||||
#[serde(default)]
|
||||
id: Option<String>,
|
||||
#[serde(default)]
|
||||
call_id: Option<String>,
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
/// Some upstreams populate `arguments` already on the
|
||||
/// `output_item.added` event for a fully-buffered tool call
|
||||
/// (i.e. when the model finalised the call before the SSE
|
||||
/// flush). Capture it so we can emit a single args delta.
|
||||
#[serde(default)]
|
||||
arguments: Option<String>,
|
||||
},
|
||||
/// `reasoning`, `web_search_call`, etc. We capture-and-ignore
|
||||
/// any item we don't model; the decoder still emits the
|
||||
/// outer events correctly.
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OutputTextDeltaEvent {
|
||||
#[serde(default)]
|
||||
item_id: Option<String>,
|
||||
#[serde(default)]
|
||||
output_index: u32,
|
||||
#[serde(default)]
|
||||
delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct FunctionCallArgumentsDeltaEvent {
|
||||
#[serde(default)]
|
||||
item_id: Option<String>,
|
||||
#[serde(default)]
|
||||
output_index: u32,
|
||||
#[serde(default)]
|
||||
delta: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ResponseCompletedEvent {
|
||||
response: ResponseShell,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ResponseShell {
|
||||
#[serde(default)]
|
||||
status: Option<String>,
|
||||
#[serde(default)]
|
||||
usage: Option<WireUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct WireUsage {
|
||||
#[serde(default)]
|
||||
input_tokens: u64,
|
||||
#[serde(default)]
|
||||
output_tokens: u64,
|
||||
#[serde(default)]
|
||||
total_tokens: u64,
|
||||
}
|
||||
|
||||
// ── Streaming decoder ───────────────────────────────────────────────
|
||||
|
||||
/// Translate the named-event Responses SSE into the provider-agnostic
|
||||
/// [`CompletionEvent`] stream the agent loop expects. The decoder
|
||||
/// holds per-stream state — output_index → tool-call-index plus
|
||||
/// the next available tool-call slot — so it can fire
|
||||
/// `ToolCallStart` exactly once per item.
|
||||
fn decode_stream<S>(
|
||||
sse: S,
|
||||
cancel: CancellationToken,
|
||||
) -> impl Stream<Item = anyhow::Result<CompletionEvent>>
|
||||
where
|
||||
S: Stream<
|
||||
Item = Result<
|
||||
eventsource_stream::Event,
|
||||
eventsource_stream::EventStreamError<reqwest::Error>,
|
||||
>,
|
||||
> + Send
|
||||
+ 'static,
|
||||
{
|
||||
async_stream::stream! {
|
||||
let mut sse = Box::pin(sse);
|
||||
// Maps an output_index that's a function_call to the tool-call
|
||||
// slot we hand downstream. Lets us correlate later
|
||||
// `function_call_arguments.delta` events back to the index
|
||||
// we already announced on `output_item.added`.
|
||||
let mut tool_index_by_output: HashMap<u32, usize> = HashMap::new();
|
||||
let mut next_tool_index: usize = 0;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
_ = cancel.cancelled() => {
|
||||
tracing::debug!("openai_responses: cancellation requested, ending stream");
|
||||
break;
|
||||
}
|
||||
next = sse.next() => {
|
||||
let Some(event) = next else { break };
|
||||
let event = match event {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
yield Err(anyhow::anyhow!("SSE transport: {e}"));
|
||||
break;
|
||||
}
|
||||
};
|
||||
// Event name lives on `event.event`; data is JSON.
|
||||
let event_name = event.event.as_str();
|
||||
let data = event.data.as_str();
|
||||
match event_name {
|
||||
"response.output_text.delta" => {
|
||||
match serde_json::from_str::<OutputTextDeltaEvent>(data) {
|
||||
Ok(d) if !d.delta.is_empty() => {
|
||||
yield Ok(CompletionEvent::TextDelta(d.delta));
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse output_text.delta; skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.output_item.added" => {
|
||||
match serde_json::from_str::<OutputItemAddedEvent>(data) {
|
||||
Ok(ev) => {
|
||||
if let OutputItem::FunctionCall {
|
||||
id,
|
||||
call_id,
|
||||
name,
|
||||
arguments,
|
||||
} = ev.item
|
||||
{
|
||||
let idx = next_tool_index;
|
||||
next_tool_index += 1;
|
||||
tool_index_by_output.insert(ev.output_index, idx);
|
||||
// Prefer the user-facing
|
||||
// `call_id` (what gets paired
|
||||
// with tool results) over the
|
||||
// internal item `id` when
|
||||
// both are present. Falls
|
||||
// back to a synthetic id so
|
||||
// history bookkeeping never
|
||||
// breaks.
|
||||
let final_id = call_id
|
||||
.or(id)
|
||||
.unwrap_or_else(|| format!("call_{idx}"));
|
||||
let final_name = name.unwrap_or_default();
|
||||
yield Ok(CompletionEvent::ToolCallStart {
|
||||
index: idx,
|
||||
id: final_id,
|
||||
name: final_name,
|
||||
});
|
||||
// Some upstreams attach the
|
||||
// fully-buffered arguments on
|
||||
// the `output_item.added`
|
||||
// event itself (rare; happens
|
||||
// when the model finalised
|
||||
// before the SSE flush).
|
||||
// Emit as a single args
|
||||
// delta if present.
|
||||
if let Some(args) = arguments
|
||||
&& !args.is_empty()
|
||||
{
|
||||
yield Ok(CompletionEvent::ToolCallArgsDelta {
|
||||
index: idx,
|
||||
args_delta: args,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse output_item.added; skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.function_call_arguments.delta" => {
|
||||
match serde_json::from_str::<FunctionCallArgumentsDeltaEvent>(data) {
|
||||
Ok(ev) => {
|
||||
let Some(&idx) = tool_index_by_output.get(&ev.output_index)
|
||||
else {
|
||||
// Args delta for an item we
|
||||
// never saw an `output_item.added`
|
||||
// for. Could happen if the
|
||||
// upstream reordered events;
|
||||
// log + skip.
|
||||
tracing::warn!(
|
||||
output_index = ev.output_index,
|
||||
"openai_responses: function_call_arguments.delta for unknown output_index"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
if !ev.delta.is_empty() {
|
||||
yield Ok(CompletionEvent::ToolCallArgsDelta {
|
||||
index: idx,
|
||||
args_delta: ev.delta,
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse function_call_arguments.delta; skipping"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"response.completed" => {
|
||||
// Final event. Pull usage + status off
|
||||
// the response shell. Status maps:
|
||||
// "completed" → no special handling
|
||||
// (caller treats as EndTurn),
|
||||
// "incomplete" → length stop.
|
||||
let (reason, usage) =
|
||||
match serde_json::from_str::<ResponseCompletedEvent>(data) {
|
||||
Ok(ev) => {
|
||||
let reason = match ev.response.status.as_deref() {
|
||||
Some("incomplete") => Some("length".to_string()),
|
||||
_ => Some("stop".to_string()),
|
||||
};
|
||||
let usage = ev.response.usage.map(|u| UsageStats {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.total_tokens,
|
||||
});
|
||||
(reason, usage)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
raw = %data,
|
||||
"openai_responses: failed to parse response.completed; ending stream with EndTurn"
|
||||
);
|
||||
(Some("stop".to_string()), None)
|
||||
}
|
||||
};
|
||||
if let Some(u) = usage {
|
||||
yield Ok(CompletionEvent::Usage(u));
|
||||
}
|
||||
yield Ok(CompletionEvent::Finish { reason });
|
||||
break;
|
||||
}
|
||||
// Bookkeeping events we don't need to surface:
|
||||
// response.created, response.in_progress,
|
||||
// response.content_part.added/.done,
|
||||
// response.output_text.done,
|
||||
// response.output_item.done,
|
||||
// response.function_call_arguments.done,
|
||||
// response.reasoning_*. Logged at debug for
|
||||
// wire-tracing.
|
||||
other => {
|
||||
tracing::trace!(
|
||||
event = other,
|
||||
"openai_responses: bookkeeping event"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::ToolCall;
|
||||
use crate::provider::{ImageData, MessagePart};
|
||||
use futures::stream;
|
||||
use url::Url;
|
||||
|
||||
fn ep() -> EndpointConfig {
|
||||
EndpointConfig {
|
||||
name: "test".into(),
|
||||
base_url: Url::parse("http://localhost:9999/v1").unwrap(),
|
||||
wire_api: crate::config::WireApi::OpenAiResponses,
|
||||
default_model: None,
|
||||
api_key: None,
|
||||
api_key_env: None,
|
||||
max_tokens: None,
|
||||
context_window: None,
|
||||
}
|
||||
}
|
||||
|
||||
// ── encode_request ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn system_messages_collapse_to_instructions() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text {
|
||||
text: "you are helpful".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: Some(0.7),
|
||||
top_p: None,
|
||||
max_tokens: Some(256),
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
assert_eq!(body["model"], "m");
|
||||
assert_eq!(body["instructions"], "you are helpful");
|
||||
assert_eq!(body["stream"], true);
|
||||
assert_eq!(body["max_output_tokens"], 256);
|
||||
assert_eq!(body["temperature"], 0.7);
|
||||
let input = body["input"].as_array().unwrap();
|
||||
// System message NOT echoed in input — it's only in
|
||||
// instructions.
|
||||
assert_eq!(input.len(), 1);
|
||||
assert_eq!(input[0]["type"], "message");
|
||||
assert_eq!(input[0]["role"], "user");
|
||||
assert_eq!(input[0]["content"], "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_system_messages_concatenate() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text {
|
||||
text: "first".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text {
|
||||
text: "second".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
assert_eq!(body["instructions"], "first\n\nsecond");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_multipart_becomes_input_parts_array() {
|
||||
let req = CompletionRequest {
|
||||
model: "vl".into(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::MultiPart {
|
||||
parts: vec![
|
||||
MessagePart::Text {
|
||||
text: "what's in this?".into(),
|
||||
},
|
||||
MessagePart::Image(ImageData {
|
||||
mime_type: "image/png".into(),
|
||||
data: "AAA=".into(),
|
||||
uri: None,
|
||||
}),
|
||||
],
|
||||
},
|
||||
}],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
let content = &body["input"][0]["content"].as_array().unwrap().clone();
|
||||
assert_eq!(content.len(), 2);
|
||||
assert_eq!(content[0]["type"], "input_text");
|
||||
assert_eq!(content[0]["text"], "what's in this?");
|
||||
assert_eq!(content[1]["type"], "input_image");
|
||||
assert_eq!(content[1]["image_url"], "data:image/png;base64,AAA=");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assistant_text_becomes_output_text_content_part() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text { text: "hi".into() },
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text {
|
||||
text: "hello there".into(),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text {
|
||||
text: "more".into(),
|
||||
},
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
let input = body["input"].as_array().unwrap();
|
||||
assert_eq!(input.len(), 3);
|
||||
assert_eq!(input[1]["type"], "message");
|
||||
assert_eq!(input[1]["role"], "assistant");
|
||||
assert_eq!(input[1]["content"][0]["type"], "output_text");
|
||||
assert_eq!(input[1]["content"][0]["text"], "hello there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_calls_and_results_round_trip_via_function_call_items() {
|
||||
let req = CompletionRequest {
|
||||
model: "m".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::ToolCalls {
|
||||
text: None,
|
||||
calls: vec![ToolCall {
|
||||
id: "call_42".into(),
|
||||
name: "read_file".into(),
|
||||
arguments: r#"{"path":"/etc/hostname"}"#.into(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::ToolResult {
|
||||
tool_call_id: "call_42".into(),
|
||||
content: "host".into(),
|
||||
},
|
||||
},
|
||||
],
|
||||
tools: vec![],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_tokens: None,
|
||||
};
|
||||
let body = encode_request(&req);
|
||||
let input = body["input"].as_array().unwrap();
|
||||
assert_eq!(input.len(), 2);
|
||||
assert_eq!(input[0]["type"], "function_call");
|
||||
assert_eq!(input[0]["call_id"], "call_42");
|
||||
assert_eq!(input[0]["name"], "read_file");
|
||||
assert_eq!(input[0]["arguments"], r#"{"path":"/etc/hostname"}"#);
|
||||
assert_eq!(input[1]["type"], "function_call_output");
|
||||
assert_eq!(input[1]["call_id"], "call_42");
|
||||
assert_eq!(input[1]["output"], "host");
|
||||
}
|
||||
|
||||
// ── decode_stream ───────────────────────────────────────────────
|
||||
|
||||
fn sse_event(name: &str, data: &str) -> eventsource_stream::Event {
|
||||
eventsource_stream::Event {
|
||||
id: String::new(),
|
||||
retry: None,
|
||||
event: name.into(),
|
||||
data: data.into(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_events(
|
||||
items: Vec<eventsource_stream::Event>,
|
||||
) -> Vec<anyhow::Result<CompletionEvent>> {
|
||||
let sse = stream::iter(
|
||||
items
|
||||
.into_iter()
|
||||
.map(Ok::<_, eventsource_stream::EventStreamError<reqwest::Error>>),
|
||||
);
|
||||
let decoded = decode_stream(sse, CancellationToken::new());
|
||||
decoded.collect().await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn decodes_text_then_finish() {
|
||||
let events = collect_events(vec![
|
||||
sse_event("response.created", "{}"),
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"msg_1","output_index":0,"delta":"hel"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"msg_1","output_index":0,"delta":"lo"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"completed","usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}"#,
|
||||
),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
let mut iter = events.into_iter();
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "hel"));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::TextDelta(t)) if t == "lo"));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::Usage(u)) if u.total_tokens == 5));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "stop"
|
||||
));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_delta_is_dropped() {
|
||||
let events = collect_events(vec![
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"m","output_index":0,"delta":""}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"completed"}}"#,
|
||||
),
|
||||
])
|
||||
.await;
|
||||
let mut completion_events = events.into_iter().map(|r| r.unwrap());
|
||||
// First event MUST be the Finish — the empty delta dropped.
|
||||
assert!(matches!(
|
||||
completion_events.next(),
|
||||
Some(CompletionEvent::Finish { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn incomplete_status_maps_to_length_finish_reason() {
|
||||
let events = collect_events(vec![sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"incomplete"}}"#,
|
||||
)])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
assert!(matches!(
|
||||
events.last(),
|
||||
Some(CompletionEvent::Finish { reason: Some(r) }) if r == "length"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_call_items_emit_toolcall_events() {
|
||||
let events = collect_events(vec![
|
||||
sse_event(
|
||||
"response.output_item.added",
|
||||
r#"{"output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_xyz","name":"read_file"}}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.function_call_arguments.delta",
|
||||
r#"{"item_id":"item_1","output_index":0,"delta":"{\"path"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.function_call_arguments.delta",
|
||||
r#"{"item_id":"item_1","output_index":0,"delta":"\":\"/etc/hostname\"}"}"#,
|
||||
),
|
||||
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
let mut iter = events.into_iter();
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallStart { index: 0, ref id, ref name })
|
||||
if id == "call_xyz" && name == "read_file"
|
||||
));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||
if args_delta == r#"{"path"#
|
||||
));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||
if args_delta == r#"":"/etc/hostname"}"#
|
||||
));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_call_added_with_inline_arguments_emits_single_args_delta() {
|
||||
// Some upstreams (rare) include the fully-buffered arguments
|
||||
// on the `output_item.added` event when the model finalised
|
||||
// the call before SSE flush. Verify both ToolCallStart and a
|
||||
// single args delta fire.
|
||||
let events = collect_events(vec![
|
||||
sse_event(
|
||||
"response.output_item.added",
|
||||
r#"{"output_index":0,"item":{"type":"function_call","call_id":"call_a","name":"f","arguments":"{\"x\":1}"}}"#,
|
||||
),
|
||||
sse_event("response.completed", r#"{"response":{"status":"completed"}}"#),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
let mut iter = events.into_iter();
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallStart { .. })
|
||||
));
|
||||
assert!(matches!(
|
||||
iter.next(),
|
||||
Some(CompletionEvent::ToolCallArgsDelta { index: 0, ref args_delta })
|
||||
if args_delta == r#"{"x":1}"#
|
||||
));
|
||||
assert!(matches!(iter.next(), Some(CompletionEvent::Finish { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancellation_ends_stream_promptly() {
|
||||
// Hand the decoder an empty stream + a triggered cancellation
|
||||
// token; it should terminate without yielding anything.
|
||||
let sse = stream::iter(Vec::<
|
||||
Result<eventsource_stream::Event, eventsource_stream::EventStreamError<reqwest::Error>>,
|
||||
>::new());
|
||||
let cancel = CancellationToken::new();
|
||||
cancel.cancel();
|
||||
let decoded = decode_stream(sse, cancel);
|
||||
let events: Vec<_> = decoded.collect().await;
|
||||
assert!(events.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_event_payload_is_skipped() {
|
||||
let events = collect_events(vec![
|
||||
sse_event("response.output_text.delta", "{not valid json"),
|
||||
sse_event(
|
||||
"response.output_text.delta",
|
||||
r#"{"item_id":"m","output_index":0,"delta":"ok"}"#,
|
||||
),
|
||||
sse_event(
|
||||
"response.completed",
|
||||
r#"{"response":{"status":"completed"}}"#,
|
||||
),
|
||||
])
|
||||
.await;
|
||||
let events: Vec<CompletionEvent> = events.into_iter().map(|r| r.unwrap()).collect();
|
||||
// First text delta dropped; second one fires.
|
||||
assert!(
|
||||
events
|
||||
.iter()
|
||||
.any(|e| matches!(e, CompletionEvent::TextDelta(t) if t == "ok"))
|
||||
);
|
||||
// No errors yielded (parse failures are warn-and-skip).
|
||||
assert!(
|
||||
events
|
||||
.iter()
|
||||
.all(|e| !matches!(e, CompletionEvent::Finish { reason: None }))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn provider_construction_is_cheap() {
    // Construction must succeed from a plain synchronous test (no async
    // runtime is set up here); the provider itself is discarded.
    let provider = OpenAIResponsesProvider::new(ep());
    drop(provider.unwrap());
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,188 +0,0 @@
|
||||
//! Per-session state for the ACP agent loop.
|
||||
//!
|
||||
//! Concurrency:
|
||||
//!
|
||||
//! - [`SessionStore`] is an `Arc<RwLock<HashMap<SessionId, …>>>`. The map
|
||||
//! itself is read-mostly: it changes only on `session/new` and never
|
||||
//! shrinks during Stage 2, so an `RwLock` keeps concurrent reads
|
||||
//! contention-free.
|
||||
//! - Each session is wrapped in its own `Arc<Mutex<SessionState>>`. Holding
|
||||
//! one session's lock doesn't block requests against any other session,
|
||||
//! which matters once a client opens multiple sessions in parallel.
|
||||
//!
|
||||
//! All operations hold a lock only long enough to copy out (or mutate) the
|
||||
//! state they need — never across an `await` that drives the upstream
|
||||
//! provider stream.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent_client_protocol::schema::{SessionId, SessionModeId};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::provider::Message;
|
||||
|
||||
/// Mode id of the gated default mode: writes and `bash` ask the client
/// for approval through `session/request_permission`.
pub const MODE_DEFAULT: &str = "default";

/// Mode id of the "auto-allow everything" mode. Spelled
/// `bypassPermissions` because that is the name Zed clients tend to
/// reference.
pub const MODE_BYPASS: &str = "bypassPermissions";

/// Mode id of the read-and-plan-only mode. The model may read files and
/// list directories freely, may write *only* inside the per-project plan
/// directory under `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`, and
/// cannot run shell commands. Meant for "draft the implementation plan,
/// then I'll review and let you execute" flows.
pub const MODE_PLAN: &str = "plan";
|
||||
|
||||
/// State carried for a single ACP session.
|
||||
///
|
||||
/// Mutated under `Mutex<SessionState>`; never share a clone across
|
||||
/// tasks expecting to see the same `cancel` token — clone the token
|
||||
/// explicitly when handing it to the streaming task.
|
||||
#[derive(Debug)]
|
||||
pub struct SessionState {
|
||||
/// Conversation history in chronological order (user / assistant
|
||||
/// turns). The system prompt is *not* stored here — it's built
|
||||
/// fresh per request so any cwd / config changes take effect.
|
||||
pub history: Vec<Message>,
|
||||
/// Working directory the client opened the session against. Used
|
||||
/// by [`crate::prompt::build_system_prompt`] and (Stage 3) by
|
||||
/// filesystem tools.
|
||||
pub cwd: PathBuf,
|
||||
/// Currently-selected model id. Format is either a bare model id
|
||||
/// (resolved against the default endpoint) or `endpoint:model`.
|
||||
/// Mutated by `session/set_model` in Stage 4; Stage 2 sets it
|
||||
/// once at session creation and never changes it.
|
||||
pub model_id: String,
|
||||
/// Cancellation handle for the in-flight prompt, if any. A fresh
|
||||
/// token is installed at the start of every `session/prompt`
|
||||
/// request; `session/cancel` fires this one. Between prompts the
|
||||
/// token is "spent" — firing it does nothing — which is fine,
|
||||
/// `session/cancel` is a no-op when there's nothing to cancel.
|
||||
pub cancel: CancellationToken,
|
||||
/// Permission gating mode. Stage 3 advertises two ids in
|
||||
/// `NewSessionResponse.modes`: [`MODE_DEFAULT`] (writes / bash
|
||||
/// prompt the user) and [`MODE_BYPASS`] (auto-allow). Mutated by
|
||||
/// `session/set_mode`.
|
||||
pub mode_id: SessionModeId,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
pub fn new(cwd: PathBuf, model_id: String) -> Self {
|
||||
Self {
|
||||
history: Vec::new(),
|
||||
cwd,
|
||||
model_id,
|
||||
cancel: CancellationToken::new(),
|
||||
mode_id: SessionModeId::new(MODE_DEFAULT),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Concurrent map of live sessions.
|
||||
///
|
||||
/// Cloning is cheap (`Arc` bump). Pass clones into every handler that
|
||||
/// needs session access; never hold a clone across an `.await` that
|
||||
/// could outlive the request.
|
||||
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Arc<Mutex<SessionState>>>>>;
|
||||
|
||||
/// Fresh, empty session store.
|
||||
pub fn new_store() -> SessionStore {
|
||||
Arc::new(RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Look up a session by id. Returns `None` if no such session is registered.
|
||||
pub async fn get(store: &SessionStore, id: &SessionId) -> Option<Arc<Mutex<SessionState>>> {
|
||||
store.read().await.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Register a fresh session. Overwrites any prior entry with the same id
|
||||
/// (which should never happen — ids are uniquely generated by the agent).
|
||||
pub async fn insert(store: &SessionStore, id: SessionId, state: SessionState) {
|
||||
store.write().await.insert(id, Arc::new(Mutex::new(state)));
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{MessageContent, Role};

    /// Shorthand for building a `SessionId` from a literal.
    fn id(s: &str) -> SessionId {
        SessionId::new(s)
    }

    /// Number of history entries currently stored for session `key`.
    async fn history_len(store: &SessionStore, key: &str) -> usize {
        let session = get(store, &id(key)).await.unwrap();
        let guard = session.lock().await;
        guard.history.len()
    }

    #[tokio::test]
    async fn insert_then_get_round_trip() {
        let store = new_store();
        insert(
            &store,
            id("s1"),
            SessionState::new(PathBuf::from("/tmp"), "m".into()),
        )
        .await;

        let session = get(&store, &id("s1")).await.expect("session present");
        let state = session.lock().await;
        assert_eq!(state.cwd, PathBuf::from("/tmp"));
        assert_eq!(state.model_id, "m");
        assert!(state.history.is_empty());
    }

    #[tokio::test]
    async fn missing_session_is_none() {
        let store = new_store();
        let looked_up = get(&store, &id("nope")).await;
        assert!(looked_up.is_none());
    }

    #[tokio::test]
    async fn history_is_per_session() {
        let store = new_store();
        for (key, cwd) in [("a", "/a"), ("b", "/b")] {
            insert(
                &store,
                id(key),
                SessionState::new(PathBuf::from(cwd), "m".into()),
            )
            .await;
        }

        // Append one turn to a's history only.
        {
            let session_a = get(&store, &id("a")).await.unwrap();
            let mut state_a = session_a.lock().await;
            state_a.history.push(Message {
                role: Role::User,
                content: MessageContent::Text {
                    text: "hello".into(),
                },
            });
        }

        // a grew by one entry; b is untouched.
        assert_eq!(history_len(&store, "a").await, 1);
        assert_eq!(history_len(&store, "b").await, 0);
    }
}
|
||||
@@ -1,462 +0,0 @@
|
||||
//! On-disk session persistence for `session/load` support.
|
||||
//!
|
||||
//! Storage layout:
|
||||
//!
|
||||
//! ```text
|
||||
//! $XDG_DATA_HOME/helexa-acp/sessions/{session_id}.json
|
||||
//! ```
|
||||
//!
|
||||
//! (Fallback to `~/.local/share/helexa-acp/sessions/` when
|
||||
//! `$XDG_DATA_HOME` is unset.) One JSON file per session. Writes
|
||||
//! happen at the end of every `session/prompt` round through
|
||||
//! [`save`], using tempfile-plus-rename so a crash mid-write can't
|
||||
//! corrupt the store. Reads happen on `session/load` via [`load`].
|
||||
//!
|
||||
//! No compaction, no rotation: files accumulate until the user
|
||||
//! cleans them up. That's deliberate — disk is cheap, and the
|
||||
//! resume-on-restart workflow matters more than tidiness. The
|
||||
//! [`SESSIONS_DIRNAME`] subdirectory is created lazily on first
|
||||
//! save so an unprivileged install path never errors at startup.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use agent_client_protocol::schema::SessionId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::provider::Message;
|
||||
|
||||
/// Directory, directly under the data home, that holds everything this
/// application persists.
const APP_DIRNAME: &str = "helexa-acp";
/// Subdirectory of [`APP_DIRNAME`] holding one JSON file per session.
const SESSIONS_DIRNAME: &str = "sessions";
/// Subdirectory of [`APP_DIRNAME`] holding plan-mode artefacts.
const PLANS_DIRNAME: &str = "plans";
|
||||
|
||||
/// The shape persisted to disk for one session. Only what we can't
|
||||
/// rebuild from the running config goes in here: the conversation
|
||||
/// history, the mode toggle, the model id, and the cwd-at-creation.
|
||||
///
|
||||
/// `created_at` / `updated_at` are seconds-since-epoch — cheap to
|
||||
/// compare, no third-party time crate, and stable across runs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersistedSession {
|
||||
pub session_id: String,
|
||||
pub cwd: PathBuf,
|
||||
pub model_id: String,
|
||||
pub mode_id: String,
|
||||
pub history: Vec<Message>,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
/// Resolve the directory that holds session JSON files. Honors
|
||||
/// `$XDG_DATA_HOME`; falls back to `~/.local/share/helexa-acp/sessions/`.
|
||||
/// Returns `None` if neither is resolvable (no `HOME` set — possible
|
||||
/// in stripped-down container environments).
|
||||
pub fn sessions_dir() -> Option<PathBuf> {
|
||||
let base = std::env::var("XDG_DATA_HOME")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| {
|
||||
std::env::var("HOME")
|
||||
.ok()
|
||||
.map(|h| PathBuf::from(h).join(".local").join("share"))
|
||||
})?;
|
||||
Some(base.join(APP_DIRNAME).join(SESSIONS_DIRNAME))
|
||||
}
|
||||
|
||||
/// Atomic save into the default sessions directory.
|
||||
pub fn save(session: &PersistedSession) -> anyhow::Result<()> {
|
||||
let dir = sessions_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||
save_to_dir(&dir, session)
|
||||
}
|
||||
|
||||
/// Load from the default sessions directory.
|
||||
pub fn load(session_id: &SessionId) -> anyhow::Result<PersistedSession> {
|
||||
let dir = sessions_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||
load_from_dir(&dir, session_id)
|
||||
}
|
||||
|
||||
/// Atomic save into an explicit directory. Writes to
|
||||
/// `{id}.json.tmp` then renames over `{id}.json`. Creates the
|
||||
/// target directory if it doesn't exist. Split from [`save`] so
|
||||
/// unit tests can target a per-test scratch dir without mutating
|
||||
/// process-global env vars.
|
||||
pub fn save_to_dir(dir: &std::path::Path, session: &PersistedSession) -> anyhow::Result<()> {
|
||||
std::fs::create_dir_all(dir).map_err(|e| anyhow::anyhow!("create {}: {e}", dir.display()))?;
|
||||
let safe = sanitize_id(&session.session_id);
|
||||
let final_path = dir.join(format!("{safe}.json"));
|
||||
let tmp_path = dir.join(format!("{safe}.json.tmp"));
|
||||
let json = serde_json::to_string_pretty(session)?;
|
||||
std::fs::write(&tmp_path, json)
|
||||
.map_err(|e| anyhow::anyhow!("write {}: {e}", tmp_path.display()))?;
|
||||
std::fs::rename(&tmp_path, &final_path)
|
||||
.map_err(|e| anyhow::anyhow!("rename → {}: {e}", final_path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load from an explicit directory. Returns a friendly error
|
||||
/// message when the session id has no file on disk so the caller
|
||||
/// can map it to a clean ACP error response.
|
||||
pub fn load_from_dir(
|
||||
dir: &std::path::Path,
|
||||
session_id: &SessionId,
|
||||
) -> anyhow::Result<PersistedSession> {
|
||||
let safe = sanitize_id(session_id.0.as_ref());
|
||||
let path = dir.join(format!("{safe}.json"));
|
||||
let bytes = std::fs::read(&path).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
anyhow::anyhow!("no persisted session at {}", path.display())
|
||||
} else {
|
||||
anyhow::anyhow!("read {}: {e}", path.display())
|
||||
}
|
||||
})?;
|
||||
let session: PersistedSession = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| anyhow::anyhow!("parse {}: {e}", path.display()))?;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// List all persisted sessions, optionally filtered by `cwd`. Used
|
||||
/// by the `session/list` handler so a client (Zed) can find the
|
||||
/// session that belongs to the workspace it's reopening.
|
||||
///
|
||||
/// `filter_cwd = None` returns every session on disk. `Some(path)`
|
||||
/// returns only sessions whose persisted `cwd` is exactly equal.
|
||||
///
|
||||
/// Files that fail to parse are skipped with a warning rather than
|
||||
/// aborting the whole list — one corrupt session shouldn't make
|
||||
/// the resume picker unusable.
|
||||
pub fn list(filter_cwd: Option<&std::path::Path>) -> anyhow::Result<Vec<PersistedSession>> {
|
||||
let dir = sessions_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("can't resolve XDG_DATA_HOME or HOME for session store"))?;
|
||||
list_in_dir(&dir, filter_cwd)
|
||||
}
|
||||
|
||||
/// Explicit-dir variant for tests, mirroring [`save_to_dir`] /
|
||||
/// [`load_from_dir`].
|
||||
pub fn list_in_dir(
|
||||
dir: &std::path::Path,
|
||||
filter_cwd: Option<&std::path::Path>,
|
||||
) -> anyhow::Result<Vec<PersistedSession>> {
|
||||
let read = match std::fs::read_dir(dir) {
|
||||
Ok(r) => r,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
|
||||
Err(e) => return Err(anyhow::anyhow!("read_dir {}: {e}", dir.display())),
|
||||
};
|
||||
let mut out = Vec::new();
|
||||
for entry in read.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|s| s.to_str()) != Some("json") {
|
||||
continue;
|
||||
}
|
||||
match std::fs::read(&path).and_then(|bytes| {
|
||||
serde_json::from_slice::<PersistedSession>(&bytes).map_err(std::io::Error::other)
|
||||
}) {
|
||||
Ok(session) => {
|
||||
if let Some(want) = filter_cwd
|
||||
&& session.cwd != want
|
||||
{
|
||||
continue;
|
||||
}
|
||||
out.push(session);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"store: skipping unparseable session file"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Most-recent first by updated_at.
|
||||
out.sort_by_key(|s| std::cmp::Reverse(s.updated_at));
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Reading the clock relative to the epoch is fallible (the system
/// clock could be set before 1970); in that unlikely case `0` is
/// returned instead of an error.
pub fn now_secs() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
|
||||
|
||||
/// Root directory for plan-mode artefacts. Mirrors [`sessions_dir`]
|
||||
/// but under `…/helexa-acp/plans/` so plans and conversation
|
||||
/// transcripts are siblings, not nested.
|
||||
pub fn plans_root() -> Option<PathBuf> {
|
||||
sessions_dir().and_then(|s| s.parent().map(|p| p.join(PLANS_DIRNAME)))
|
||||
}
|
||||
|
||||
/// Per-project plan directory:
|
||||
/// `$XDG_DATA_HOME/helexa-acp/plans/<project-id>/`. The id derives
|
||||
/// from the session's cwd so plans for the same project survive
|
||||
/// across cwd-changes (a `/home/foo/git/bar` ↔ symlinked
|
||||
/// `/srv/checkout/bar` would technically diverge, accepted as a
|
||||
/// won't-fix corner case).
|
||||
pub fn plan_dir_for(cwd: &std::path::Path) -> Option<PathBuf> {
|
||||
plans_root().map(|root| root.join(project_id_for(cwd)))
|
||||
}
|
||||
|
||||
/// Deterministic, human-readable project identifier. Format:
|
||||
/// `<basename>-<8-hex>` where the 8-hex suffix is FNV-1a of the
|
||||
/// full path. Basename keeps the path skim-readable when poking
|
||||
/// around `$XDG_DATA_HOME` by hand; the hash suffix disambiguates
|
||||
/// repos that share a final path component (e.g. multiple
|
||||
/// `/.../checkout/beat` checkouts).
|
||||
///
|
||||
/// FNV-1a rather than `std::collections::hash::DefaultHasher`
|
||||
/// because the latter (SipHash) reseeds per process, so it'd give
|
||||
/// us a different project_id on every run.
|
||||
pub fn project_id_for(cwd: &std::path::Path) -> String {
|
||||
let basename = cwd
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("unknown");
|
||||
let sanitised: String = basename
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let hash = fnv1a_32(cwd.to_string_lossy().as_bytes());
|
||||
format!("{sanitised}-{hash:08x}")
|
||||
}
|
||||
|
||||
/// 32-bit FNV-1a hash of `bytes`.
///
/// Starting from the offset basis `0x811c9dc5`, each byte is XOR-ed in
/// and the state is then multiplied by the FNV prime `0x01000193` with
/// wrapping arithmetic. Deterministic and dependency-free; only used to
/// derive project ids, so it is not meant to be cryptographic.
fn fnv1a_32(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;
    bytes.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u32::from(byte)).wrapping_mul(PRIME)
    })
}
|
||||
|
||||
/// Format seconds-since-epoch as an ISO 8601 / RFC 3339 string
|
||||
/// (`YYYY-MM-DDTHH:MM:SSZ`) for `SessionInfo.updated_at`. Returns
|
||||
/// `None` for values outside the representable range, in which
|
||||
/// case the caller should omit the field.
|
||||
pub fn unix_to_iso8601(secs: u64) -> Option<String> {
|
||||
use chrono::TimeZone;
|
||||
let dt = chrono::Utc.timestamp_opt(secs as i64, 0).single()?;
|
||||
Some(dt.to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
|
||||
}
|
||||
|
||||
/// Strip anything that isn't a safe filename character so a
|
||||
/// mischievous (or just unconventional) session id can't escape
|
||||
/// the sessions directory.
|
||||
fn sanitize_id(id: &str) -> String {
|
||||
id.chars()
|
||||
.map(|c| {
|
||||
if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
|
||||
c
|
||||
} else {
|
||||
'_'
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{MessageContent, Role};
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Unique scratch dir per call. The tests pass this dir straight to
    /// the `*_to_dir` / `*_from_dir` / `*_in_dir` functions so they
    /// never mutate `$XDG_DATA_HOME`, which would race across the
    /// parallel test harness.
    ///
    /// The name combines the pid, the clock's sub-second nanos and a
    /// process-wide counter. The counter is what guarantees uniqueness
    /// within one test binary: two tests running in parallel can read
    /// the same clock value when the clock is coarse.
    fn unique_dir() -> PathBuf {
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);

        let base = std::env::var("CARGO_TARGET_TMPDIR")
            .ok()
            .map(PathBuf::from)
            .unwrap_or_else(std::env::temp_dir);
        let pid = std::process::id();
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.subsec_nanos())
            .unwrap_or(0);
        let seq = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        let dir = base.join(format!("helexa-acp-store-test-{pid}-{nanos}-{seq}"));
        std::fs::create_dir_all(&dir).expect("create test dir");
        dir
    }

    /// Two-turn sample session with fixed timestamps.
    fn sample(id: &str) -> PersistedSession {
        PersistedSession {
            session_id: id.into(),
            cwd: PathBuf::from("/home/me/proj"),
            model_id: "Qwen/Qwen3.6-27B".into(),
            mode_id: "default".into(),
            history: vec![
                Message {
                    role: Role::User,
                    content: MessageContent::Text {
                        text: "hello".into(),
                    },
                },
                Message {
                    role: Role::Assistant,
                    content: MessageContent::Text { text: "hi".into() },
                },
            ],
            created_at: 1_700_000_000,
            updated_at: 1_700_000_001,
        }
    }

    #[test]
    fn round_trip_save_then_load() {
        let dir = unique_dir();
        save_to_dir(&dir, &sample("hxa-1")).expect("save");
        let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
        assert_eq!(loaded.session_id, "hxa-1");
        assert_eq!(loaded.cwd, PathBuf::from("/home/me/proj"));
        assert_eq!(loaded.history.len(), 2);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_missing_session_errors_with_not_found_message() {
        let dir = unique_dir();
        let err = load_from_dir(&dir, &SessionId::new("nope")).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("no persisted session"),
            "want NotFound, got: {msg}"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn save_overwrites_existing_atomically() {
        let dir = unique_dir();
        save_to_dir(&dir, &sample("hxa-1")).expect("save");
        let mut updated = sample("hxa-1");
        updated.history.push(Message {
            role: Role::User,
            content: MessageContent::Text {
                text: "third turn".into(),
            },
        });
        updated.updated_at = 1_700_000_500;
        save_to_dir(&dir, &updated).expect("re-save");
        let loaded = load_from_dir(&dir, &SessionId::new("hxa-1")).expect("load");
        assert_eq!(loaded.history.len(), 3);
        assert_eq!(loaded.updated_at, 1_700_000_500);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn save_then_load_preserves_tool_calls_and_results() {
        use crate::provider::ToolCall;
        let dir = unique_dir();
        let mut session = sample("hxa-2");
        session.history.push(Message {
            role: Role::Assistant,
            content: MessageContent::ToolCalls {
                text: Some("calling".into()),
                calls: vec![ToolCall {
                    id: "call_0".into(),
                    name: "read_file".into(),
                    arguments: r#"{"path":"/etc/hostname"}"#.into(),
                }],
            },
        });
        session.history.push(Message {
            role: Role::Tool,
            content: MessageContent::ToolResult {
                tool_call_id: "call_0".into(),
                content: "host".into(),
            },
        });
        save_to_dir(&dir, &session).expect("save");
        let loaded = load_from_dir(&dir, &SessionId::new("hxa-2")).expect("load");
        assert_eq!(loaded.history.len(), 4);
        match &loaded.history[2].content {
            MessageContent::ToolCalls { calls, .. } => {
                assert_eq!(calls[0].name, "read_file");
            }
            other => panic!("expected ToolCalls, got {other:?}"),
        }
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn list_filters_by_cwd_and_sorts_recent_first() {
        let dir = unique_dir();
        let mut a = sample("a");
        a.cwd = PathBuf::from("/home/me/proj-x");
        a.updated_at = 1_700_000_010;
        let mut b = sample("b");
        b.cwd = PathBuf::from("/home/me/proj-x");
        b.updated_at = 1_700_000_020;
        let mut c = sample("c");
        c.cwd = PathBuf::from("/home/me/elsewhere");
        c.updated_at = 1_700_000_030;
        save_to_dir(&dir, &a).unwrap();
        save_to_dir(&dir, &b).unwrap();
        save_to_dir(&dir, &c).unwrap();

        let proj_x = PathBuf::from("/home/me/proj-x");
        let list = list_in_dir(&dir, Some(&proj_x)).unwrap();
        let ids: Vec<&str> = list.iter().map(|s| s.session_id.as_str()).collect();
        // Filtered to proj-x; b before a because b is more recent.
        assert_eq!(ids, vec!["b", "a"]);

        let all = list_in_dir(&dir, None).unwrap();
        assert_eq!(all.len(), 3);
        // Global list still sorted recent-first across all cwds.
        assert_eq!(all[0].session_id, "c");

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn list_returns_empty_for_missing_dir() {
        let scratch = unique_dir();
        let dir = scratch.join("does-not-exist");
        let list = list_in_dir(&dir, None).unwrap();
        assert!(list.is_empty());
        // `unique_dir` created the parent; remove it like every other test.
        let _ = std::fs::remove_dir_all(&scratch);
    }

    #[test]
    fn list_skips_unparseable_files() {
        let dir = unique_dir();
        save_to_dir(&dir, &sample("good")).unwrap();
        std::fs::write(dir.join("garbage.json"), b"{not valid json").unwrap();
        let list = list_in_dir(&dir, None).unwrap();
        // Garbage skipped; good survives.
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].session_id, "good");
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn iso8601_formats_unix_seconds() {
        // 2024-01-01T00:00:00Z is 1704067200 unix seconds.
        assert_eq!(
            unix_to_iso8601(1_704_067_200),
            Some("2024-01-01T00:00:00Z".into())
        );
        assert_eq!(unix_to_iso8601(0), Some("1970-01-01T00:00:00Z".into()));
    }

    #[test]
    fn sanitize_id_rejects_path_traversal() {
        // In `../../etc/passwd` the six leading characters
        // (`.`, `.`, `/`, `.`, `.`, `/`) and the `/` between `etc` and
        // `passwd` are all disallowed, and each becomes `_`.
        assert_eq!(sanitize_id("../../etc/passwd"), "______etc_passwd");
        assert_eq!(sanitize_id("ok-name_42"), "ok-name_42");
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,300 +0,0 @@
|
||||
//! Tool schemas sent to the upstream model on every completion.
|
||||
//!
|
||||
//! These are the OpenAI-function-style declarations the LLM sees in
|
||||
//! `CompletionRequest.tools`; the runtime dispatch happens in
|
||||
//! [`crate::tool_runner`]. Keeping declarations and execution in
|
||||
//! separate modules makes it easy to add a tool without touching the
|
||||
//! runner, and vice versa.
|
||||
//!
|
||||
//! Stage 3 ships five: filesystem read / write / edit, directory
|
||||
//! listing, and `bash`. Image generation, web fetch, MCP-derived
|
||||
//! tools, etc. are out of scope here.
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::provider::ToolSpec;
|
||||
|
||||
/// Name of the tool that reads a text file.
pub const READ_FILE: &str = "read_file";
/// Name of the tool that writes (replaces) a file's contents.
pub const WRITE_FILE: &str = "write_file";
/// Name of the tool that replaces one exact substring in a file.
pub const EDIT_FILE: &str = "edit_file";
/// Name of the tool that lists a directory's entries.
pub const LIST_DIR: &str = "list_dir";
/// Name of the tool that runs a shell command.
pub const BASH: &str = "bash";
|
||||
|
||||
/// Build the static tool list passed to the model on every prompt.
|
||||
/// Cheap — the JSON Schema fragments are constructed each call but
|
||||
/// the bodies are small constants. If this ever shows up in a
|
||||
/// profile we can `OnceLock` the Vec.
|
||||
pub fn all_tools() -> Vec<ToolSpec> {
|
||||
vec![
|
||||
ToolSpec {
|
||||
name: READ_FILE.to_string(),
|
||||
description: "Read the contents of a text file. Returns the file's text.".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file."
|
||||
},
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "Optional 1-based line number to start reading from.",
|
||||
"minimum": 1
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Optional maximum number of lines to read.",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: WRITE_FILE.to_string(),
|
||||
description: "Write text content to a file, replacing any existing contents. \
|
||||
Creates the file (and parent directories) if needed."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Full new contents of the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: EDIT_FILE.to_string(),
|
||||
description: "Replace one exact substring in a file with another. \
|
||||
Fails if `old_text` does not appear in the file, or appears more than once. \
|
||||
Use multiple edit_file calls for multiple edits."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file."
|
||||
},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Exact text fragment to replace. Must be unique within the file."
|
||||
},
|
||||
"new_text": {
|
||||
"type": "string",
|
||||
"description": "Replacement text."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: LIST_DIR.to_string(),
|
||||
description:
|
||||
"List the entries of a directory. Returns names and a (f|d|l) kind per entry."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the directory."
|
||||
}
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: BASH.to_string(),
|
||||
description: "Run a shell command via `sh -c`. \
|
||||
Returns combined stdout+stderr and the exit status. \
|
||||
The command runs in the session's working directory unless `cwd` is given."
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Shell command line, evaluated by `sh -c`."
|
||||
},
|
||||
"cwd": {
|
||||
"type": "string",
|
||||
"description": "Optional absolute path to run the command from."
|
||||
}
|
||||
},
|
||||
"required": ["command"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Try to infer which tool was intended from the shape of an
|
||||
/// `arguments` object alone. Used by the agent when the model
|
||||
/// emits a `<tool_call>` whose JSON has the right arguments but a
|
||||
/// missing or invalid top-level `name` field — a recurring
|
||||
/// Qwen3.6-27B failure mode.
|
||||
///
|
||||
/// Returns `Some(name)` only when the argument keys uniquely match
|
||||
/// exactly one tool in the catalogue. Ambiguous shapes (`{path}`
|
||||
/// alone could be either [`READ_FILE`] or [`LIST_DIR`]) return
|
||||
/// `None` so the caller surfaces a Failed-card and lets the model
|
||||
/// retry rather than guessing wrong.
|
||||
///
|
||||
/// Inference table (key set → tool):
|
||||
///
|
||||
/// | Keys | Tool |
|
||||
/// |---------------------------------------|--------------|
|
||||
/// | `{command}` or `{command, cwd}` | `bash` |
|
||||
/// | `{path, content}` | `write_file` |
|
||||
/// | `{path, old_text, new_text}` | `edit_file` |
|
||||
/// | `{path}` / `{path, line}` / `{path, line, limit}` | *ambiguous* — None |
|
||||
/// | (anything else) | None |
|
||||
pub fn infer_tool_name(arguments: &serde_json::Value) -> Option<&'static str> {
|
||||
let obj = arguments.as_object()?;
|
||||
let keys: std::collections::HashSet<&str> = obj.keys().map(|s| s.as_str()).collect();
|
||||
|
||||
// `command` is unique to bash. Allow the optional `cwd` arg
|
||||
// alongside but nothing else (any unrecognised keys → bail and
|
||||
// let the model retry rather than misroute).
|
||||
if keys.contains("command") && keys.iter().all(|k| matches!(*k, "command" | "cwd")) {
|
||||
return Some(BASH);
|
||||
}
|
||||
// `content` is unique to write_file.
|
||||
if keys.contains("content") && keys.contains("path") && keys.len() == 2 {
|
||||
return Some(WRITE_FILE);
|
||||
}
|
||||
// `old_text` + `new_text` are unique to edit_file.
|
||||
if keys.contains("old_text")
|
||||
&& keys.contains("new_text")
|
||||
&& keys.contains("path")
|
||||
&& keys.len() == 3
|
||||
{
|
||||
return Some(EDIT_FILE);
|
||||
}
|
||||
// `{path}` / `{path, line}` / `{path, line, limit}` overlap
|
||||
// between read_file (file contents) and list_dir (directory
|
||||
// contents). No safe inference — refuse.
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the inference on an owned JSON value.
    fn inferred(args: serde_json::Value) -> Option<&'static str> {
        infer_tool_name(&args)
    }

    #[test]
    fn all_tools_has_five_named_entries() {
        let tools = all_tools();
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(
            names,
            vec![READ_FILE, WRITE_FILE, EDIT_FILE, LIST_DIR, BASH]
        );
    }

    #[test]
    fn infer_bash_from_command_only() {
        assert_eq!(inferred(serde_json::json!({"command": "ls /tmp"})), Some(BASH));
    }

    #[test]
    fn infer_bash_from_command_and_cwd() {
        assert_eq!(
            inferred(serde_json::json!({"command": "ls", "cwd": "/tmp"})),
            Some(BASH)
        );
    }

    #[test]
    fn infer_bash_from_mkdir_like_real_failure() {
        // Taken verbatim from the agent failure that motivated the
        // helper (helexa-acp.log @ 10:03:11).
        let args = serde_json::json!({
            "command": "mkdir -p /home/grenade/git/beat/beat/doc/plan/{01-discovery,02-segmentation,03-description,04-summary,05-output}"
        });
        assert_eq!(inferred(args), Some(BASH));
    }

    #[test]
    fn infer_write_file() {
        assert_eq!(
            inferred(serde_json::json!({"path": "/tmp/x", "content": "hi"})),
            Some(WRITE_FILE)
        );
    }

    #[test]
    fn infer_edit_file() {
        let args = serde_json::json!({
            "path": "/tmp/x", "old_text": "a", "new_text": "b"
        });
        assert_eq!(inferred(args), Some(EDIT_FILE));
    }

    #[test]
    fn refuse_ambiguous_path_only() {
        assert_eq!(inferred(serde_json::json!({"path": "/tmp/x"})), None);
    }

    #[test]
    fn refuse_ambiguous_path_with_optionals() {
        // read_file takes these optionals and list_dir does not, but
        // Qwen does not emit them reliably, so their presence cannot
        // be used to disambiguate. Refuse.
        let args = serde_json::json!({"path": "/tmp/x", "line": 1, "limit": 50});
        assert_eq!(inferred(args), None);
    }

    #[test]
    fn refuse_command_with_extra_unknown_keys() {
        // Defence in depth: an unknown key next to `command` means the
        // intended tool is unclear, so refuse rather than guess.
        let args = serde_json::json!({"command": "ls", "extra": "?"});
        assert_eq!(inferred(args), None);
    }

    #[test]
    fn refuse_empty_args() {
        assert_eq!(inferred(serde_json::json!({})), None);
    }

    #[test]
    fn refuse_non_object_args() {
        assert_eq!(inferred(serde_json::json!("not an object")), None);
    }

    #[test]
    fn every_tool_has_an_object_parameter_schema() {
        for tool in all_tools() {
            let schema = &tool.parameters;
            let ty = schema.get("type").and_then(|v| v.as_str());
            assert_eq!(
                ty,
                Some("object"),
                "tool {} parameters.type must be \"object\"",
                tool.name
            );
            assert!(
                schema.get("properties").is_some(),
                "tool {} missing properties",
                tool.name
            );
            assert!(
                schema.get("required").is_some(),
                "tool {} missing required list",
                tool.name
            );
        }
    }
}
|
||||
@@ -1,38 +0,0 @@
|
||||
[package]
|
||||
name = "helexa-bench"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "helexa-bench"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cortex-core = { workspace = true }
|
||||
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
figment = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
|
||||
# SQLite system-of-record. `bundled` compiles SQLite from source so the
|
||||
# binary has no libsqlite3 runtime dependency — matches the project's
|
||||
# single-static-binary packaging.
|
||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
|
||||
[dev-dependencies]
|
||||
axum = { workspace = true }
|
||||
# Jail (isolated cwd + env) for config tests.
|
||||
figment = { workspace = true, features = ["test"] }
|
||||
@@ -1,159 +0,0 @@
|
||||
//! Outbound calls to a benchmark target: build identity, host discovery,
|
||||
//! and warm-model enumeration. Neuron targets use the native neuron API;
|
||||
//! `openai` targets use the OpenAI-compatible surface (preliminary).
|
||||
|
||||
use crate::config::{TargetConfig, TargetKind};
|
||||
use anyhow::{Context, Result};
|
||||
use cortex_core::build_info::BuildInfo;
|
||||
use cortex_core::discovery::DiscoveryResponse;
|
||||
use cortex_core::harness::ModelInfo;
|
||||
use cortex_core::openai::ModelsResponse;
|
||||
use std::time::Duration;
|
||||
|
||||
/// How long to wait on the cheap metadata polls (version/discovery/models).
|
||||
const META_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
pub struct TargetClient {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl TargetClient {
|
||||
pub fn new(request_timeout: Duration) -> Result<Self> {
|
||||
let http = reqwest::Client::builder()
|
||||
.timeout(request_timeout)
|
||||
.build()
|
||||
.context("building HTTP client")?;
|
||||
Ok(TargetClient { http })
|
||||
}
|
||||
|
||||
pub fn http(&self) -> &reqwest::Client {
|
||||
&self.http
|
||||
}
|
||||
|
||||
/// Chat-completions URL for the target.
|
||||
pub fn chat_url(&self, target: &TargetConfig) -> String {
|
||||
let base = target.endpoint.trim_end_matches('/');
|
||||
match target.kind {
|
||||
// neuron exposes OpenAI routes under /v1.
|
||||
TargetKind::Neuron => format!("{base}/v1/chat/completions"),
|
||||
// openai endpoint is the /v1 base already (bench.py convention).
|
||||
TargetKind::Openai => format!("{base}/chat/completions"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build identity. Neuron: `GET /version`. Openai: a synthetic
|
||||
/// placeholder keyed by `"external"` so the version-aware skip logic
|
||||
/// treats it as one stable build (comparison runs are manual anyway).
|
||||
pub async fn fetch_version(&self, target: &TargetConfig) -> Result<BuildInfo> {
|
||||
match target.kind {
|
||||
TargetKind::Neuron => {
|
||||
let base = target.endpoint.trim_end_matches('/');
|
||||
let info = self
|
||||
.http
|
||||
.get(format!("{base}/version"))
|
||||
.timeout(META_TIMEOUT)
|
||||
.send()
|
||||
.await
|
||||
.context("GET /version")?
|
||||
.error_for_status()
|
||||
.context("GET /version status")?
|
||||
.json::<BuildInfo>()
|
||||
.await
|
||||
.context("decoding /version")?;
|
||||
Ok(info)
|
||||
}
|
||||
TargetKind::Openai => {
|
||||
let mut info = BuildInfo::unknown();
|
||||
info.git_sha = "external".to_string();
|
||||
Ok(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Host discovery (neuron only).
|
||||
pub async fn fetch_discovery(
|
||||
&self,
|
||||
target: &TargetConfig,
|
||||
) -> Result<Option<DiscoveryResponse>> {
|
||||
if target.kind != TargetKind::Neuron {
|
||||
return Ok(None);
|
||||
}
|
||||
let base = target.endpoint.trim_end_matches('/');
|
||||
let disco = self
|
||||
.http
|
||||
.get(format!("{base}/discovery"))
|
||||
.timeout(META_TIMEOUT)
|
||||
.send()
|
||||
.await
|
||||
.context("GET /discovery")?
|
||||
.error_for_status()
|
||||
.context("GET /discovery status")?
|
||||
.json::<DiscoveryResponse>()
|
||||
.await
|
||||
.context("decoding /discovery")?;
|
||||
Ok(Some(disco))
|
||||
}
|
||||
|
||||
/// Warm models — those ready to serve without a cold load.
|
||||
///
|
||||
/// Neuron: `GET /models` filtered to `status == "loaded"` (skips
|
||||
/// `recovering`/`poisoned`). Openai: `GET /models`, honouring the
|
||||
/// helexa `loaded` extension when present, else treating all listed
|
||||
/// models as warm.
|
||||
pub async fn warm_models(&self, target: &TargetConfig) -> Result<Vec<ModelInfo>> {
|
||||
let base = target.endpoint.trim_end_matches('/');
|
||||
match target.kind {
|
||||
TargetKind::Neuron => {
|
||||
let models = self
|
||||
.http
|
||||
.get(format!("{base}/models"))
|
||||
.timeout(META_TIMEOUT)
|
||||
.send()
|
||||
.await
|
||||
.context("GET /models")?
|
||||
.error_for_status()
|
||||
.context("GET /models status")?
|
||||
.json::<Vec<ModelInfo>>()
|
||||
.await
|
||||
.context("decoding /models")?;
|
||||
Ok(models
|
||||
.into_iter()
|
||||
.filter(|m| m.status == "loaded")
|
||||
.collect())
|
||||
}
|
||||
TargetKind::Openai => {
|
||||
let resp = self
|
||||
.http
|
||||
.get(format!("{base}/models"))
|
||||
.timeout(META_TIMEOUT)
|
||||
.send()
|
||||
.await
|
||||
.context("GET /models")?
|
||||
.error_for_status()
|
||||
.context("GET /models status")?
|
||||
.json::<ModelsResponse>()
|
||||
.await
|
||||
.context("decoding /models")?;
|
||||
Ok(resp
|
||||
.data
|
||||
.into_iter()
|
||||
.filter(|m| {
|
||||
// honour the helexa `loaded` extension if present
|
||||
m.extra
|
||||
.get("loaded")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.map(|m| ModelInfo {
|
||||
id: m.id,
|
||||
harness: "openai".to_string(),
|
||||
status: "loaded".to_string(),
|
||||
devices: Vec::new(),
|
||||
vram_used_mb: None,
|
||||
capabilities: Vec::new(),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
//! Bench configuration: loaded from `helexa-bench.toml` with figment,
|
||||
//! `BENCH_`-prefixed env overrides (mirrors `NeuronConfig::load`).
|
||||
|
||||
use figment::{
|
||||
Figment,
|
||||
providers::{Env, Format, Toml},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Top-level bench config.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchConfig {
|
||||
#[serde(default)]
|
||||
pub bench: BenchSettings,
|
||||
#[serde(default)]
|
||||
pub scenarios: ScenarioConfig,
|
||||
/// Endpoints to benchmark. At least one is required for `run`/`once`.
|
||||
#[serde(default)]
|
||||
pub targets: Vec<TargetConfig>,
|
||||
}
|
||||
|
||||
/// Loop/timing knobs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchSettings {
|
||||
/// Pause between full sweeps of all targets.
|
||||
#[serde(default = "default_sweep_interval")]
|
||||
pub sweep_interval_secs: u64,
|
||||
/// Target number of measured samples to record for a given
|
||||
/// (target, build SHA, model, scenario). Once met, later sweeps skip
|
||||
/// that cell — so a fully-sampled build costs only cheap version
|
||||
/// polls until a new SHA ships.
|
||||
#[serde(default = "default_samples")]
|
||||
pub samples_per_version: u32,
|
||||
/// Pause between successive measured iterations against one model.
|
||||
#[serde(default = "default_iter_pause")]
|
||||
pub iteration_pause_secs: u64,
|
||||
/// Per-request timeout (cold lazy-loads can be slow; generous like
|
||||
/// bench.py's 600s default).
|
||||
#[serde(default = "default_timeout")]
|
||||
pub request_timeout_secs: u64,
|
||||
/// SQLite system-of-record path.
|
||||
#[serde(default = "default_db_path")]
|
||||
pub db_path: String,
|
||||
}
|
||||
|
||||
impl Default for BenchSettings {
|
||||
fn default() -> Self {
|
||||
BenchSettings {
|
||||
sweep_interval_secs: default_sweep_interval(),
|
||||
samples_per_version: default_samples(),
|
||||
iteration_pause_secs: default_iter_pause(),
|
||||
request_timeout_secs: default_timeout(),
|
||||
db_path: default_db_path(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BenchSettings {
|
||||
pub fn iteration_pause(&self) -> Duration {
|
||||
Duration::from_secs(self.iteration_pause_secs)
|
||||
}
|
||||
pub fn request_timeout(&self) -> Duration {
|
||||
Duration::from_secs(self.request_timeout_secs)
|
||||
}
|
||||
pub fn sweep_interval(&self) -> Duration {
|
||||
Duration::from_secs(self.sweep_interval_secs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Which scenarios to run and their shared parameters.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ScenarioConfig {
|
||||
/// Approximate prompt sizes (in tokens) — one chat-latency scenario
|
||||
/// is generated per size, e.g. `chat:128`, `chat:4096`. This is the
|
||||
/// per-cell dimension that the version-aware skip logic keys on.
|
||||
#[serde(default = "default_prompt_sizes")]
|
||||
pub prompt_sizes: Vec<u32>,
|
||||
/// Max generated tokens per request.
|
||||
#[serde(default = "default_max_tokens")]
|
||||
pub max_tokens: u64,
|
||||
}
|
||||
|
||||
impl Default for ScenarioConfig {
|
||||
fn default() -> Self {
|
||||
ScenarioConfig {
|
||||
prompt_sizes: default_prompt_sizes(),
|
||||
max_tokens: default_max_tokens(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// One endpoint to benchmark.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TargetConfig {
|
||||
/// Stable label used as the engine column and in the DB.
|
||||
pub name: String,
|
||||
/// Which protocol/metadata surface the target exposes.
|
||||
#[serde(default)]
|
||||
pub kind: TargetKind,
|
||||
/// Base URL. For `neuron`: the daemon root (e.g.
|
||||
/// `http://beast.internal:13131`). For `openai`: the OpenAI `/v1`
|
||||
/// base (e.g. `http://host:8080/v1`).
|
||||
pub endpoint: String,
|
||||
/// Optional display label override for reports (defaults to `name`).
|
||||
#[serde(default)]
|
||||
pub label: Option<String>,
|
||||
}
|
||||
|
||||
impl TargetConfig {
|
||||
pub fn display_label(&self) -> &str {
|
||||
self.label.as_deref().unwrap_or(&self.name)
|
||||
}
|
||||
}
|
||||
|
||||
/// The two target surfaces. `neuron` gets rich build metadata and warm
|
||||
/// model discovery via the native neuron API; `openai` is the seam for
|
||||
/// later comparison against mistral.rs / llama.cpp / vLLM (phase 1
|
||||
/// implements `neuron` fully; `openai` is preliminary plumbing).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TargetKind {
|
||||
#[default]
|
||||
Neuron,
|
||||
Openai,
|
||||
}
|
||||
|
||||
impl BenchConfig {
|
||||
pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<figment::Error>> {
|
||||
Figment::new()
|
||||
.merge(Toml::file(path))
|
||||
.merge(Env::prefixed("BENCH_").split("__"))
|
||||
.extract()
|
||||
.map_err(Box::new)
|
||||
}
|
||||
}
|
||||
|
||||
/// Default pause between sweeps, in seconds.
fn default_sweep_interval() -> u64 {
    1_800
}

/// Default number of measured samples per cell.
fn default_samples() -> u32 {
    5
}

/// Default pause between measured iterations, in seconds.
fn default_iter_pause() -> u64 {
    2
}

/// Default per-request timeout, in seconds.
fn default_timeout() -> u64 {
    600
}

/// Default location of the SQLite store.
fn default_db_path() -> String {
    String::from("/var/lib/helexa-bench/bench.sqlite")
}

/// Default approximate prompt sizes, in tokens.
fn default_prompt_sizes() -> Vec<u32> {
    Vec::from([128, 4096])
}

/// Default cap on generated tokens per request.
fn default_max_tokens() -> u64 {
    256
}
|
||||
|
||||
#[cfg(test)]
// Jail's closure has to return figment::Result, whose large Err type
// belongs to figment rather than this crate, so the lint is silenced
// for this module only.
#[allow(clippy::result_large_err)]
mod tests {
    use super::*;
    use figment::Jail;

    /// File name used inside the jail's isolated working directory.
    const CONFIG_FILE: &str = "helexa-bench.toml";

    #[test]
    fn loads_minimal_with_defaults() {
        Jail::expect_with(|jail| {
            jail.create_file(
                CONFIG_FILE,
                r#"
[[targets]]
name = "beast"
endpoint = "http://beast.internal:13131"
"#,
            )?;

            let cfg = BenchConfig::load(CONFIG_FILE).unwrap();

            // Only the target is given; everything else is defaulted.
            assert_eq!(cfg.targets.len(), 1);
            assert_eq!(cfg.targets[0].kind, TargetKind::Neuron);
            assert_eq!(cfg.bench.samples_per_version, 5);
            assert_eq!(cfg.scenarios.prompt_sizes, vec![128, 4096]);
            Ok(())
        });
    }

    #[test]
    fn env_overrides_apply() {
        Jail::expect_with(|jail| {
            jail.create_file(
                CONFIG_FILE,
                r#"
[bench]
samples_per_version = 3
[[targets]]
name = "benjy"
kind = "openai"
endpoint = "http://benjy:8080/v1"
"#,
            )?;
            // The environment value must win over the file's `3`.
            jail.set_env("BENCH_BENCH__SAMPLES_PER_VERSION", "9");

            let cfg = BenchConfig::load(CONFIG_FILE).unwrap();

            assert_eq!(cfg.bench.samples_per_version, 9);
            assert_eq!(cfg.targets[0].kind, TargetKind::Openai);
            Ok(())
        });
    }
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//! helexa-bench — a continuous, version-aware benchmark harness for the
|
||||
//! neuron fleet. It hits each neuron directly, exercises an extensible
|
||||
//! scenario suite against every warm model, and records each run with
|
||||
//! full build/version provenance into SQLite so improvements can be
|
||||
//! tracked automatically across neuron implementation updates.
|
||||
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod report;
|
||||
pub mod scenario;
|
||||
pub mod store;
|
||||
pub mod sweep;
|
||||
@@ -1,126 +0,0 @@
|
||||
//! helexa-bench CLI.
|
||||
//!
|
||||
//! - `run` — continuous daemon (systemd default): sweep, sleep, repeat.
|
||||
//! - `once` — a single sweep, then exit (manual / CI).
|
||||
//! - `report` — render the SQLite store as a results table.
|
||||
//!
|
||||
//! Runs on a single-threaded runtime: the workload is batch-1 sequential
|
||||
//! (one request at a time, the regime we measure), and it lets the
|
||||
//! SQLite connection live across awaits without `Sync` gymnastics.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use clap::{Parser, Subcommand};
|
||||
use helexa_bench::config::BenchConfig;
|
||||
use helexa_bench::report;
|
||||
use helexa_bench::store::Store;
|
||||
use helexa_bench::sweep::Sweeper;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "helexa-bench")]
|
||||
#[command(about = "Continuous version-aware benchmark harness for the neuron fleet")]
|
||||
#[command(version)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Command {
|
||||
/// Run sweeps continuously, pausing `sweep_interval_secs` between them.
|
||||
Run {
|
||||
#[arg(short, long, default_value = "helexa-bench.toml")]
|
||||
config: String,
|
||||
},
|
||||
/// Run a single sweep over all targets, then exit.
|
||||
Once {
|
||||
#[arg(short, long, default_value = "helexa-bench.toml")]
|
||||
config: String,
|
||||
},
|
||||
/// Render recorded results. Uses `--db` if given, else the db_path
|
||||
/// from `--config`.
|
||||
Report {
|
||||
#[arg(short, long, default_value = "helexa-bench.toml")]
|
||||
config: String,
|
||||
/// Override the SQLite path (skips reading the config file).
|
||||
#[arg(long)]
|
||||
db: Option<String>,
|
||||
/// Output format.
|
||||
#[arg(long, default_value = "md")]
|
||||
format: Format,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, clap::ValueEnum)]
|
||||
enum Format {
|
||||
Md,
|
||||
Json,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("building tokio runtime")?;
|
||||
rt.block_on(run(cli))
|
||||
}
|
||||
|
||||
async fn run(cli: Cli) -> Result<()> {
|
||||
match cli.command {
|
||||
Command::Run { config } => {
|
||||
let cfg = load_config(&config)?;
|
||||
require_targets(&cfg)?;
|
||||
let sweeper = Sweeper::new(cfg)?;
|
||||
tracing::info!("helexa-bench started; entering continuous sweep loop");
|
||||
sweeper.run_forever().await
|
||||
}
|
||||
Command::Once { config } => {
|
||||
let cfg = load_config(&config)?;
|
||||
require_targets(&cfg)?;
|
||||
let sweeper = Sweeper::new(cfg)?;
|
||||
let summary = sweeper.run_once().await?;
|
||||
tracing::info!(
|
||||
measured = summary.measured,
|
||||
skipped = summary.skipped,
|
||||
failed = summary.failed,
|
||||
unreachable = summary.targets_unreachable,
|
||||
"single sweep complete"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Command::Report { config, db, format } => {
|
||||
let db_path = match db {
|
||||
Some(p) => p,
|
||||
None => load_config(&config)?.bench.db_path,
|
||||
};
|
||||
let store = Store::open(&db_path)?;
|
||||
let rows = store.report_rows()?;
|
||||
let rendered = match format {
|
||||
Format::Md => report::render_markdown(&rows),
|
||||
Format::Json => report::render_json(&rows)?,
|
||||
};
|
||||
println!("{rendered}");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config(path: &str) -> Result<BenchConfig> {
|
||||
BenchConfig::load(path)
|
||||
.map_err(|e| anyhow::anyhow!("{e}"))
|
||||
.with_context(|| format!("loading config {path}"))
|
||||
}
|
||||
|
||||
fn require_targets(cfg: &BenchConfig) -> Result<()> {
|
||||
if cfg.targets.is_empty() {
|
||||
anyhow::bail!("no targets configured — add at least one [[targets]] entry");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
//! Render the SQLite store as a results table — the automated
|
||||
//! replacement for hand-editing `doc/benchmarks.md`. Columns match that
|
||||
//! doc: engine, model, prompt tok, TTFT (s), decode tok/s, total (s),
|
||||
//! plus the build SHA each cell was measured against.
|
||||
|
||||
use crate::store::ReportRow;
|
||||
use anyhow::Result;
|
||||
|
||||
pub fn render_markdown(rows: &[ReportRow]) -> String {
|
||||
let mut out = String::new();
|
||||
out.push_str(
|
||||
"| engine | model | prompt tok | TTFT (s) | decode tok/s | total (s) | build | n |\n",
|
||||
);
|
||||
out.push_str("|---|---|---:|---:|---:|---:|---|---:|\n");
|
||||
for r in rows {
|
||||
let ptok = r
|
||||
.prompt_tokens
|
||||
.map(|t| t.to_string())
|
||||
.unwrap_or_else(|| format!("~{}", r.prompt_size_approx));
|
||||
out.push_str(&format!(
|
||||
"| {} | {} | {} | {} | {} | {} | `{}` | {} |\n",
|
||||
r.target_name,
|
||||
r.model_id,
|
||||
ptok,
|
||||
fmt_opt(r.ttft_s_median, 3),
|
||||
fmt_opt(r.decode_tps_median, 1),
|
||||
fmt_opt(r.total_s_median, 3),
|
||||
r.git_sha,
|
||||
r.samples,
|
||||
));
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn render_json(rows: &[ReportRow]) -> Result<String> {
|
||||
let arr: Vec<serde_json::Value> = rows
|
||||
.iter()
|
||||
.map(|r| {
|
||||
serde_json::json!({
|
||||
"engine": r.target_name,
|
||||
"model": r.model_id,
|
||||
"scenario": r.scenario_id,
|
||||
"prompt_size_approx": r.prompt_size_approx,
|
||||
"prompt_tokens": r.prompt_tokens,
|
||||
"ttft_s_median": r.ttft_s_median,
|
||||
"decode_tps_median": r.decode_tps_median,
|
||||
"total_s_median": r.total_s_median,
|
||||
"git_sha": r.git_sha,
|
||||
"samples": r.samples,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Ok(serde_json::to_string_pretty(&arr)?)
|
||||
}
|
||||
|
||||
/// Formats an optional value with `places` decimal places; `None`
/// becomes an em dash (`—`).
fn fmt_opt(v: Option<f64>, places: usize) -> String {
    v.map_or_else(|| "—".to_string(), |value| format!("{value:.places$}"))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully-populated row, as produced by a neuron target.
    fn complete_row() -> ReportRow {
        ReportRow {
            target_name: "beast".into(),
            model_id: "Qwen/Qwen3.6-27B".into(),
            scenario_id: "chat:128".into(),
            prompt_size_approx: 128,
            git_sha: "30d50d6".into(),
            prompt_tokens: Some(130),
            ttft_s_median: Some(0.123),
            decode_tps_median: Some(45.6),
            total_s_median: Some(1.234),
            samples: 5,
        }
    }

    /// A row with no token count and no decode rate.
    fn sparse_row() -> ReportRow {
        ReportRow {
            target_name: "benjy".into(),
            model_id: "m".into(),
            scenario_id: "chat:128".into(),
            prompt_size_approx: 128,
            git_sha: "abc".into(),
            prompt_tokens: None,
            ttft_s_median: Some(0.1),
            decode_tps_median: None,
            total_s_median: Some(0.5),
            samples: 1,
        }
    }

    #[test]
    fn markdown_has_header_and_row() {
        let md = render_markdown(&[complete_row()]);
        for needle in ["| engine |", "beast", "`30d50d6`", "0.123"] {
            assert!(md.contains(needle));
        }
    }

    #[test]
    fn missing_decode_renders_dash() {
        let md = render_markdown(&[sparse_row()]);
        // Falls back to the approximate size and the em-dash placeholder.
        assert!(md.contains("~128"));
        assert!(md.contains("—"));
    }
}
|
||||
@@ -1,238 +0,0 @@
|
||||
//! The extensible test suite.
|
||||
//!
|
||||
//! A [`Scenario`] puts one warm model through one shaped request and
|
||||
//! reports operator-felt metrics (TTFT, decode tok/s, total). Phase 1
|
||||
//! ships the chat-latency family ported faithfully from `script/bench.py`;
|
||||
//! the trait is the seam for future families (vision, concurrency,
|
||||
//! long-generation, cold-start) selected per model via [`Scenario::applies_to`].
|
||||
|
||||
use crate::config::ScenarioConfig;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use cortex_core::harness::ModelInfo;
|
||||
use cortex_core::openai::ChatCompletionChunk;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use serde_json::json;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// A paragraph of filler re-used to synthesise prompts of a target
|
||||
/// approximate token count (~4 chars/token heuristic — close enough for
|
||||
/// bucketing; real token counts are read back from the usage object).
|
||||
/// Mirrors `script/bench.py::FILLER`.
|
||||
const FILLER: &str = "The quick brown fox jumps over the lazy dog while the band plays \
|
||||
a slow waltz in the background and somebody counts the beats. ";
|
||||
|
||||
/// `/no_think`: Qwen3-family soft switch keeping thinking models from
|
||||
/// burning the token budget invisibly. Harmless for non-thinking models.
|
||||
const QUESTION: &str = "\n\nRetell the scene above as a vivid story of about 300 words. /no_think";
|
||||
|
||||
/// Build a synthetic prompt of approximately `approx_tokens` tokens.
|
||||
/// Ported from `bench.py::build_prompt`.
|
||||
pub fn build_prompt(approx_tokens: u32) -> String {
|
||||
let target_chars = (approx_tokens.max(16) as usize) * 4;
|
||||
let reps = target_chars / FILLER.len() + 1;
|
||||
let mut body = FILLER.repeat(reps);
|
||||
body.truncate(target_chars);
|
||||
body.push_str(QUESTION);
|
||||
body
|
||||
}
|
||||
|
||||
/// Per-request inputs shared by every scenario.
|
||||
pub struct RunCtx<'a> {
|
||||
pub client: &'a reqwest::Client,
|
||||
/// Fully-qualified chat-completions URL for the target.
|
||||
pub chat_url: String,
|
||||
pub model_id: String,
|
||||
pub max_tokens: u64,
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
/// Metrics, as an operator would feel them, for one measured request.
#[derive(Debug, Clone)]
pub struct ScenarioMetrics {
    /// Seconds until the first content chunk arrived.
    pub ttft_s: f64,

    /// Completion tokens divided by the decode window. `None` when the
    /// window is too short to be honest (≤ 200 ms), matching bench.py.
    pub decode_tps: Option<f64>,

    /// Wall-clock seconds for the whole request.
    pub total_s: f64,

    /// Prompt tokens taken from the final `usage` object, if the server
    /// sent one.
    pub prompt_tokens: Option<u64>,

    /// Completion tokens: from `usage` when present, otherwise the
    /// number of content chunks.
    pub completion_tokens: u64,
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Scenario: Send + Sync {
|
||||
/// Stable id, e.g. `chat:128`. Used as the version-aware skip key
|
||||
/// dimension and recorded against every run.
|
||||
fn id(&self) -> &str;
|
||||
|
||||
/// Approximate prompt size in tokens (the cell dimension), recorded
|
||||
/// for reporting.
|
||||
fn prompt_size(&self) -> u32;
|
||||
|
||||
/// Whether this scenario should run against the given model. Default
|
||||
/// runs against everything; vision/audio scenarios will gate on
|
||||
/// [`ModelInfo::capabilities`].
|
||||
fn applies_to(&self, _model: &ModelInfo) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Issue one shaped request and measure it.
|
||||
async fn run(&self, ctx: &RunCtx) -> Result<ScenarioMetrics>;
|
||||
}
|
||||
|
||||
/// Build the active scenario set from config. One chat-latency scenario
|
||||
/// per configured prompt size.
|
||||
pub fn build_scenarios(cfg: &ScenarioConfig) -> Vec<Box<dyn Scenario>> {
|
||||
cfg.prompt_sizes
|
||||
.iter()
|
||||
.map(|&size| {
|
||||
Box::new(ChatLatencyScenario {
|
||||
id: format!("chat:{size}"),
|
||||
approx_prompt_tokens: size,
|
||||
}) as Box<dyn Scenario>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Streamed single-request chat-completions latency probe — the batch-1
|
||||
/// regime bench.py measures.
|
||||
pub struct ChatLatencyScenario {
|
||||
id: String,
|
||||
approx_prompt_tokens: u32,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scenario for ChatLatencyScenario {
|
||||
fn id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn prompt_size(&self) -> u32 {
|
||||
self.approx_prompt_tokens
|
||||
}
|
||||
|
||||
async fn run(&self, ctx: &RunCtx) -> Result<ScenarioMetrics> {
|
||||
let prompt = build_prompt(self.approx_prompt_tokens);
|
||||
let payload = json!({
|
||||
"model": ctx.model_id,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": ctx.max_tokens,
|
||||
"temperature": 0,
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
});
|
||||
|
||||
let fut = stream_and_measure(ctx, &payload);
|
||||
tokio::time::timeout(ctx.timeout, fut)
|
||||
.await
|
||||
.map_err(|_| anyhow!("request timed out after {:?}", ctx.timeout))?
|
||||
}
|
||||
}
|
||||
|
||||
/// The SSE-timing core, ported from `bench.py::one_run`. Kept free of the
|
||||
/// `Scenario` trait so it's unit-testable against a mock byte stream.
|
||||
async fn stream_and_measure(
|
||||
ctx: &RunCtx<'_>,
|
||||
payload: &serde_json::Value,
|
||||
) -> Result<ScenarioMetrics> {
|
||||
let start = Instant::now();
|
||||
let resp = ctx
|
||||
.client
|
||||
.post(&ctx.chat_url)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await
|
||||
.context("sending chat request")?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(anyhow!("upstream returned {status}: {}", body.trim()));
|
||||
}
|
||||
|
||||
let mut stream = resp.bytes_stream().eventsource();
|
||||
let mut first: Option<Instant> = None;
|
||||
let mut last: Option<Instant> = None;
|
||||
let mut chunk_count: u64 = 0;
|
||||
let mut prompt_tokens: Option<u64> = None;
|
||||
let mut completion_tokens: Option<u64> = None;
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event.context("reading SSE stream")?;
|
||||
let now = Instant::now();
|
||||
let data = event.data.trim();
|
||||
if data.is_empty() || data == "[DONE]" {
|
||||
continue;
|
||||
}
|
||||
let chunk: ChatCompletionChunk = match serde_json::from_str(data) {
|
||||
Ok(c) => c,
|
||||
Err(_) => continue, // tolerate non-JSON keepalive frames
|
||||
};
|
||||
if let Some(choice) = chunk.choices.first()
|
||||
&& choice
|
||||
.delta
|
||||
.get("content")
|
||||
.and_then(|c| c.as_str())
|
||||
.is_some_and(|s| !s.is_empty())
|
||||
{
|
||||
if first.is_none() {
|
||||
first = Some(now);
|
||||
}
|
||||
last = Some(now);
|
||||
chunk_count += 1;
|
||||
}
|
||||
if let Some(usage) = chunk.usage {
|
||||
prompt_tokens = Some(usage.prompt_tokens);
|
||||
completion_tokens = Some(usage.completion_tokens);
|
||||
}
|
||||
}
|
||||
let end = Instant::now();
|
||||
|
||||
let first = first.ok_or_else(|| anyhow!("no content chunks received"))?;
|
||||
|
||||
// neuron emits one SSE chunk per visible token, so chunk_count is an
|
||||
// engine-truth count when no usage frame is sent.
|
||||
let tokens = completion_tokens.filter(|&t| t > 0).unwrap_or(chunk_count);
|
||||
// decode rate is only meaningful over a real inter-chunk window.
|
||||
let window = last
|
||||
.filter(|&l| l > first)
|
||||
.map(|l| (l - first).as_secs_f64())
|
||||
.unwrap_or(0.0);
|
||||
Ok(ScenarioMetrics {
|
||||
ttft_s: (first - start).as_secs_f64(),
|
||||
decode_tps: if window > 0.2 {
|
||||
Some(tokens as f64 / window)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
total_s: (end - start).as_secs_f64(),
|
||||
prompt_tokens,
|
||||
completion_tokens: tokens,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A larger token target yields a longer prompt, and the prompt keeps
    /// its trailing `/no_think` marker.
    #[test]
    fn prompt_grows_with_token_target() {
        let short_prompt = build_prompt(128);
        let long_prompt = build_prompt(4096);
        assert!(long_prompt.len() > short_prompt.len());
        // ~4 chars/token + the trailing question.
        assert!(short_prompt.len() >= 128 * 4);
        assert!(short_prompt.ends_with("/no_think"));
    }

    /// The max(approx, 16) floor means even a zero target produces a
    /// non-trivial prompt.
    #[test]
    fn prompt_floor_for_tiny_targets() {
        let prompt = build_prompt(0);
        assert!(prompt.len() >= 16 * 4);
    }
}
|
||||
@@ -1,400 +0,0 @@
|
||||
//! SQLite system-of-record. One row per measured iteration, keyed so a
|
||||
//! benchmark can be attributed to the exact neuron build that produced
|
||||
//! it. Replaces hand edits to `doc/benchmarks.md`.
|
||||
//!
|
||||
//! Calls are synchronous (SQLite is local and the sweep is batch-1
|
||||
//! sequential), so the connection is used inline between `await` points,
|
||||
//! never held across one.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use rusqlite::{Connection, params};
|
||||
use std::path::Path;
|
||||
|
||||
/// One benchmark iteration — successful or failed — together with the
/// provenance needed to attribute it to a specific target, build, model
/// and scenario. Maps one-to-one onto a row of the `runs` table.
#[derive(Debug, Clone)]
pub struct RunRecord {
    /// Timestamp of the iteration (RFC3339).
    pub ts: String,

    // --- target ---
    pub target_name: String,
    pub target_kind: String,
    pub endpoint: String,

    // --- host (from /discovery) ---
    pub hostname: Option<String>,
    pub driver_version: Option<String>,
    pub cuda_version: Option<String>,
    pub gpus_json: Option<String>,

    // --- neuron build (from /version) ---
    pub git_sha: String,
    pub git_sha_long: Option<String>,
    pub package_version: String,
    pub git_dirty: bool,
    pub build_timestamp: Option<String>,
    pub rustc_version: Option<String>,
    pub profile: Option<String>,
    pub features_json: String,
    pub candle_version: Option<String>,

    // --- bench's own build ---
    pub bench_version: String,
    pub bench_sha: String,

    // --- model ---
    pub model_id: String,
    pub harness: String,
    pub capabilities_json: String,
    pub devices_json: String,

    // --- scenario ---
    pub scenario_id: String,
    pub prompt_size_approx: u32,
    pub prompt_tokens_actual: Option<u64>,
    pub max_tokens: u64,

    // --- metrics ---
    pub ttft_s: Option<f64>,
    pub decode_tps: Option<f64>,
    pub total_s: Option<f64>,
    pub completion_tokens: Option<u64>,

    // --- outcome ---
    pub ok: bool,
    pub error: Option<String>,
}
|
||||
|
||||
pub struct Store {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
impl Store {
|
||||
/// Open (creating parent dirs + schema as needed).
|
||||
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
if let Some(parent) = path.parent()
|
||||
&& !parent.as_os_str().is_empty()
|
||||
{
|
||||
std::fs::create_dir_all(parent)
|
||||
.with_context(|| format!("creating db dir {}", parent.display()))?;
|
||||
}
|
||||
let conn = Connection::open(path)
|
||||
.with_context(|| format!("opening sqlite db {}", path.display()))?;
|
||||
Self::init(&conn)?;
|
||||
Ok(Store { conn })
|
||||
}
|
||||
|
||||
/// In-memory store for tests.
|
||||
#[cfg(test)]
|
||||
pub fn open_in_memory() -> Result<Self> {
|
||||
let conn = Connection::open_in_memory()?;
|
||||
Self::init(&conn)?;
|
||||
Ok(Store { conn })
|
||||
}
|
||||
|
||||
fn init(conn: &Connection) -> Result<()> {
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS runs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts TEXT NOT NULL,
|
||||
target_name TEXT NOT NULL,
|
||||
target_kind TEXT NOT NULL,
|
||||
endpoint TEXT NOT NULL,
|
||||
hostname TEXT,
|
||||
driver_version TEXT,
|
||||
cuda_version TEXT,
|
||||
gpus_json TEXT,
|
||||
git_sha TEXT NOT NULL,
|
||||
git_sha_long TEXT,
|
||||
package_version TEXT NOT NULL,
|
||||
git_dirty INTEGER NOT NULL,
|
||||
build_timestamp TEXT,
|
||||
rustc_version TEXT,
|
||||
profile TEXT,
|
||||
features_json TEXT NOT NULL,
|
||||
candle_version TEXT,
|
||||
bench_version TEXT NOT NULL,
|
||||
bench_sha TEXT NOT NULL,
|
||||
model_id TEXT NOT NULL,
|
||||
harness TEXT NOT NULL,
|
||||
capabilities_json TEXT NOT NULL,
|
||||
devices_json TEXT NOT NULL,
|
||||
scenario_id TEXT NOT NULL,
|
||||
prompt_size_approx INTEGER NOT NULL,
|
||||
prompt_tokens_actual INTEGER,
|
||||
max_tokens INTEGER NOT NULL,
|
||||
ttft_s REAL,
|
||||
decode_tps REAL,
|
||||
total_s REAL,
|
||||
completion_tokens INTEGER,
|
||||
ok INTEGER NOT NULL,
|
||||
error TEXT
|
||||
);
|
||||
-- The version-aware skip query keys on this tuple. scenario_id
|
||||
-- encodes the prompt size (chat:<n>), so it subsumes the cell.
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_cell
|
||||
ON runs (target_name, git_sha, model_id, scenario_id, ok);
|
||||
"#,
|
||||
)
|
||||
.context("initialising sqlite schema")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Count successful samples already recorded for a cell. Only `ok`
|
||||
/// rows count toward the per-version target so transient failures
|
||||
/// don't permanently starve a cell.
|
||||
pub fn count_samples(
|
||||
&self,
|
||||
target_name: &str,
|
||||
git_sha: &str,
|
||||
model_id: &str,
|
||||
scenario_id: &str,
|
||||
) -> Result<u32> {
|
||||
let n: i64 = self.conn.query_row(
|
||||
"SELECT COUNT(*) FROM runs WHERE target_name=?1 AND git_sha=?2 \
|
||||
AND model_id=?3 AND scenario_id=?4 AND ok=1",
|
||||
params![target_name, git_sha, model_id, scenario_id],
|
||||
|row| row.get(0),
|
||||
)?;
|
||||
Ok(n as u32)
|
||||
}
|
||||
|
||||
pub fn insert_run(&self, r: &RunRecord) -> Result<()> {
|
||||
self.conn.execute(
|
||||
"INSERT INTO runs (
|
||||
ts, target_name, target_kind, endpoint,
|
||||
hostname, driver_version, cuda_version, gpus_json,
|
||||
git_sha, git_sha_long, package_version, git_dirty,
|
||||
build_timestamp, rustc_version, profile, features_json, candle_version,
|
||||
bench_version, bench_sha,
|
||||
model_id, harness, capabilities_json, devices_json,
|
||||
scenario_id, prompt_size_approx, prompt_tokens_actual, max_tokens,
|
||||
ttft_s, decode_tps, total_s, completion_tokens,
|
||||
ok, error
|
||||
) VALUES (
|
||||
?1, ?2, ?3, ?4,
|
||||
?5, ?6, ?7, ?8,
|
||||
?9, ?10, ?11, ?12,
|
||||
?13, ?14, ?15, ?16, ?17,
|
||||
?18, ?19,
|
||||
?20, ?21, ?22, ?23,
|
||||
?24, ?25, ?26, ?27,
|
||||
?28, ?29, ?30, ?31,
|
||||
?32, ?33
|
||||
)",
|
||||
params![
|
||||
r.ts,
|
||||
r.target_name,
|
||||
r.target_kind,
|
||||
r.endpoint,
|
||||
r.hostname,
|
||||
r.driver_version,
|
||||
r.cuda_version,
|
||||
r.gpus_json,
|
||||
r.git_sha,
|
||||
r.git_sha_long,
|
||||
r.package_version,
|
||||
r.git_dirty as i64,
|
||||
r.build_timestamp,
|
||||
r.rustc_version,
|
||||
r.profile,
|
||||
r.features_json,
|
||||
r.candle_version,
|
||||
r.bench_version,
|
||||
r.bench_sha,
|
||||
r.model_id,
|
||||
r.harness,
|
||||
r.capabilities_json,
|
||||
r.devices_json,
|
||||
r.scenario_id,
|
||||
r.prompt_size_approx,
|
||||
r.prompt_tokens_actual,
|
||||
r.max_tokens,
|
||||
r.ttft_s,
|
||||
r.decode_tps,
|
||||
r.total_s,
|
||||
r.completion_tokens,
|
||||
r.ok as i64,
|
||||
r.error,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// One reportable cell: the median metrics over the most-recently-seen
|
||||
/// build SHA for each (target, model, scenario).
|
||||
pub fn report_rows(&self) -> Result<Vec<ReportRow>> {
|
||||
// For each (target, model, scenario), find the SHA of the latest
|
||||
// successful run, then median that SHA's samples.
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT target_name, model_id, scenario_id, prompt_size_approx, git_sha,
|
||||
ttft_s, decode_tps, total_s, prompt_tokens_actual
|
||||
FROM runs
|
||||
WHERE ok=1
|
||||
ORDER BY target_name, model_id, scenario_id, id",
|
||||
)?;
|
||||
let rows = stmt.query_map([], |row| {
|
||||
Ok(RawRow {
|
||||
target_name: row.get(0)?,
|
||||
model_id: row.get(1)?,
|
||||
scenario_id: row.get(2)?,
|
||||
prompt_size_approx: row.get(3)?,
|
||||
git_sha: row.get(4)?,
|
||||
ttft_s: row.get(5)?,
|
||||
decode_tps: row.get(6)?,
|
||||
total_s: row.get(7)?,
|
||||
prompt_tokens_actual: row.get(8)?,
|
||||
})
|
||||
})?;
|
||||
let raws: Vec<RawRow> = rows.collect::<rusqlite::Result<_>>()?;
|
||||
Ok(aggregate(raws))
|
||||
}
|
||||
}
|
||||
|
||||
/// One successful run as read back from the `runs` table, before any
/// aggregation. Holds only the columns the report needs.
struct RawRow {
    // Cell identity.
    target_name: String,
    model_id: String,
    scenario_id: String,
    prompt_size_approx: u32,
    // Build that produced the sample.
    git_sha: String,
    // Measurements; each may be absent on an individual row.
    ttft_s: Option<f64>,
    decode_tps: Option<f64>,
    total_s: Option<f64>,
    prompt_tokens_actual: Option<u64>,
}
|
||||
|
||||
/// A fully aggregated cell, ready to be rendered in the report table.
#[derive(Debug, Clone, PartialEq)]
pub struct ReportRow {
    // Cell identity.
    pub target_name: String,
    pub model_id: String,
    pub scenario_id: String,
    pub prompt_size_approx: u32,
    /// Build SHA whose samples were aggregated.
    pub git_sha: String,
    /// Actual prompt token count, when any sample carried one.
    pub prompt_tokens: Option<u64>,
    // Per-metric medians; `None` when no sample carried the metric.
    pub ttft_s_median: Option<f64>,
    pub decode_tps_median: Option<f64>,
    pub total_s_median: Option<f64>,
    /// Number of samples aggregated into this row.
    pub samples: usize,
}
|
||||
|
||||
/// Group by (target, model, scenario), keep only the latest SHA's rows
|
||||
/// (latest = the SHA of the last-inserted row, since input is id-ordered),
|
||||
/// and median each metric.
|
||||
fn aggregate(raws: Vec<RawRow>) -> Vec<ReportRow> {
|
||||
use std::collections::BTreeMap;
|
||||
// key -> (latest_sha, rows for that sha)
|
||||
let mut groups: BTreeMap<(String, String, String), Vec<RawRow>> = BTreeMap::new();
|
||||
for r in raws {
|
||||
groups
|
||||
.entry((
|
||||
r.target_name.clone(),
|
||||
r.model_id.clone(),
|
||||
r.scenario_id.clone(),
|
||||
))
|
||||
.or_default()
|
||||
.push(r);
|
||||
}
|
||||
let mut out = Vec::new();
|
||||
for ((target_name, model_id, scenario_id), rows) in groups {
|
||||
// id-ordered, so the last row carries the latest SHA.
|
||||
let latest_sha = rows.last().map(|r| r.git_sha.clone()).unwrap_or_default();
|
||||
let cell: Vec<&RawRow> = rows.iter().filter(|r| r.git_sha == latest_sha).collect();
|
||||
let prompt_size_approx = cell.first().map(|r| r.prompt_size_approx).unwrap_or(0);
|
||||
out.push(ReportRow {
|
||||
target_name,
|
||||
model_id,
|
||||
scenario_id,
|
||||
prompt_size_approx,
|
||||
git_sha: latest_sha,
|
||||
prompt_tokens: cell.iter().find_map(|r| r.prompt_tokens_actual),
|
||||
ttft_s_median: median(cell.iter().filter_map(|r| r.ttft_s)),
|
||||
decode_tps_median: median(cell.iter().filter_map(|r| r.decode_tps)),
|
||||
total_s_median: median(cell.iter().filter_map(|r| r.total_s)),
|
||||
samples: cell.len(),
|
||||
});
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Returns the median of `values`, or `None` when the iterator is empty.
///
/// For an even number of values the result is the mean of the two central
/// elements.
///
/// Values are ordered with [`f64::total_cmp`], which is a genuine total
/// order. Comparing via `partial_cmp` with an `Equal` fallback is not a
/// total order once a NaN is present, and the standard sort routines may
/// panic or produce an unspecified order when handed such a comparator.
fn median(values: impl Iterator<Item = f64>) -> Option<f64> {
    let mut sorted: Vec<f64> = values.collect();
    if sorted.is_empty() {
        return None;
    }
    // Stability is irrelevant for plain numbers, so the unstable sort is fine.
    sorted.sort_unstable_by(f64::total_cmp);
    // lo == hi for odd lengths (the middle element); they straddle the
    // centre for even lengths. Avoids a `% 2` branch.
    let lo = (sorted.len() - 1) / 2;
    let hi = sorted.len() / 2;
    Some((sorted[lo] + sorted[hi]) / 2.0)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully populated record for the given cell and outcome.
    /// Failed records carry the error text `boom`.
    fn rec(target: &str, sha: &str, model: &str, scenario: &str, ok: bool) -> RunRecord {
        let error = if ok { None } else { Some("boom".into()) };
        RunRecord {
            ts: "2026-06-13T00:00:00Z".into(),
            target_name: target.into(),
            target_kind: "neuron".into(),
            endpoint: "http://x:13131".into(),
            hostname: Some("x".into()),
            driver_version: None,
            cuda_version: None,
            gpus_json: None,
            git_sha: sha.into(),
            git_sha_long: None,
            package_version: "0.1.16".into(),
            git_dirty: false,
            build_timestamp: None,
            rustc_version: None,
            profile: None,
            features_json: "[]".into(),
            candle_version: None,
            bench_version: "0.1.16".into(),
            bench_sha: "deadbee".into(),
            model_id: model.into(),
            harness: "candle".into(),
            capabilities_json: "[]".into(),
            devices_json: "[]".into(),
            scenario_id: scenario.into(),
            prompt_size_approx: 128,
            prompt_tokens_actual: Some(130),
            max_tokens: 256,
            ttft_s: Some(0.1),
            decode_tps: Some(50.0),
            total_s: Some(1.0),
            completion_tokens: Some(50),
            ok,
            error,
        }
    }

    /// Failed rows are stored but never counted toward a cell's samples.
    #[test]
    fn counts_only_successful_samples() {
        let store = Store::open_in_memory().unwrap();
        for ok in [true, true, false] {
            store
                .insert_run(&rec("beast", "abc", "m", "chat:128", ok))
                .unwrap();
        }
        assert_eq!(
            store.count_samples("beast", "abc", "m", "chat:128").unwrap(),
            2
        );
        // Different SHA is a different cell.
        assert_eq!(
            store.count_samples("beast", "xyz", "m", "chat:128").unwrap(),
            0
        );
    }

    /// The report aggregates only the samples of the most recent SHA.
    #[test]
    fn report_uses_latest_sha_per_cell() {
        let store = Store::open_in_memory().unwrap();
        // old build
        store
            .insert_run(&rec("beast", "old", "m", "chat:128", true))
            .unwrap();
        // new build, two samples
        for ttft in [0.2, 0.4] {
            let mut sample = rec("beast", "new", "m", "chat:128", true);
            sample.ttft_s = Some(ttft);
            store.insert_run(&sample).unwrap();
        }

        let report = store.report_rows().unwrap();
        assert_eq!(report.len(), 1);
        let row = &report[0];
        assert_eq!(row.git_sha, "new");
        assert_eq!(row.samples, 2);
        assert!((row.ttft_s_median.unwrap() - 0.3).abs() < 1e-9);
    }
}
|
||||
@@ -1,250 +0,0 @@
|
||||
//! The version-aware sweep loop.
|
||||
//!
|
||||
//! Each sweep visits every configured target, polls its build identity
|
||||
//! and warm models, and tops up benchmark samples per
|
||||
//! (target, build SHA, model, scenario) to `samples_per_version`. Cells
|
||||
//! already at target are skipped — so once every neuron's current build
|
||||
//! is fully sampled, sweeps cost only the cheap metadata polls until a
|
||||
//! new SHA ships. Runs are recorded to SQLite with full provenance.
|
||||
|
||||
use crate::client::TargetClient;
|
||||
use crate::config::{BenchConfig, TargetConfig, TargetKind};
|
||||
use crate::scenario::{RunCtx, build_scenarios};
|
||||
use crate::store::{RunRecord, Store};
|
||||
use anyhow::Result;
|
||||
use cortex_core::build_info::BuildInfo;
|
||||
use cortex_core::discovery::DiscoveryResponse;
|
||||
use cortex_core::harness::ModelInfo;
|
||||
|
||||
/// helexa-bench's own build version.
|
||||
fn bench_version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Build SHA of helexa-bench itself. CI injects it at compile time through
/// `HELEXA_BUILD_SHA`; when the variable is unset or empty (ad-hoc local
/// builds) the result is `"unknown"`.
fn bench_sha() -> String {
    let sha = match option_env!("HELEXA_BUILD_SHA") {
        Some(injected) if !injected.is_empty() => injected,
        _ => "unknown",
    };
    sha.to_string()
}
|
||||
|
||||
/// Counters describing what a single sweep did.
#[derive(Debug, Default, Clone)]
pub struct SweepSummary {
    /// Iterations that succeeded and were recorded.
    pub measured: usize,
    /// Cells skipped because they already had enough samples.
    pub skipped: usize,
    /// Iterations that failed (these are recorded as failed runs).
    pub failed: usize,
    /// Targets whose sweep ended with an error.
    pub targets_unreachable: usize,
}
|
||||
|
||||
pub struct Sweeper {
|
||||
cfg: BenchConfig,
|
||||
client: TargetClient,
|
||||
store: Store,
|
||||
}
|
||||
|
||||
impl Sweeper {
|
||||
pub fn new(cfg: BenchConfig) -> Result<Self> {
|
||||
let client = TargetClient::new(cfg.bench.request_timeout())?;
|
||||
let store = Store::open(&cfg.bench.db_path)?;
|
||||
Ok(Sweeper { cfg, client, store })
|
||||
}
|
||||
|
||||
/// Run sweeps forever, pausing `sweep_interval` between them.
|
||||
pub async fn run_forever(&self) -> ! {
|
||||
loop {
|
||||
match self.run_once().await {
|
||||
Ok(s) => tracing::info!(
|
||||
measured = s.measured,
|
||||
skipped = s.skipped,
|
||||
failed = s.failed,
|
||||
unreachable = s.targets_unreachable,
|
||||
"sweep complete"
|
||||
),
|
||||
Err(e) => tracing::error!(error = %format!("{e:#}"), "sweep errored"),
|
||||
}
|
||||
tracing::debug!(
|
||||
secs = self.cfg.bench.sweep_interval_secs,
|
||||
"sleeping until next sweep"
|
||||
);
|
||||
tokio::time::sleep(self.cfg.bench.sweep_interval()).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// One full pass over all targets.
|
||||
pub async fn run_once(&self) -> Result<SweepSummary> {
|
||||
let mut summary = SweepSummary::default();
|
||||
for target in &self.cfg.targets {
|
||||
if let Err(e) = self.sweep_target(target, &mut summary).await {
|
||||
summary.targets_unreachable += 1;
|
||||
tracing::warn!(target = %target.name, error = %format!("{e:#}"), "target skipped");
|
||||
}
|
||||
}
|
||||
Ok(summary)
|
||||
}
|
||||
|
||||
async fn sweep_target(&self, target: &TargetConfig, summary: &mut SweepSummary) -> Result<()> {
|
||||
let build = self.client.fetch_version(target).await?;
|
||||
let discovery = self.client.fetch_discovery(target).await.unwrap_or(None);
|
||||
let models = self.client.warm_models(target).await?;
|
||||
|
||||
tracing::info!(
|
||||
target = %target.name,
|
||||
sha = %build.git_sha,
|
||||
warm_models = models.len(),
|
||||
"sweeping target"
|
||||
);
|
||||
|
||||
let scenarios = build_scenarios(&self.cfg.scenarios);
|
||||
for model in &models {
|
||||
for scenario in scenarios.iter().filter(|s| s.applies_to(model)) {
|
||||
let have = self.store.count_samples(
|
||||
&target.name,
|
||||
&build.git_sha,
|
||||
&model.id,
|
||||
scenario.id(),
|
||||
)?;
|
||||
let need = self.cfg.bench.samples_per_version.saturating_sub(have);
|
||||
if need == 0 {
|
||||
summary.skipped += 1;
|
||||
tracing::debug!(
|
||||
target = %target.name, model = %model.id, scenario = scenario.id(),
|
||||
sha = %build.git_sha, "cell already satisfied, skipping"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let ctx = RunCtx {
|
||||
client: self.client.http(),
|
||||
chat_url: self.client.chat_url(target),
|
||||
model_id: model.id.clone(),
|
||||
max_tokens: self.cfg.scenarios.max_tokens,
|
||||
timeout: self.cfg.bench.request_timeout(),
|
||||
};
|
||||
|
||||
// One unmeasured warmup when the cell is empty (matches
|
||||
// bench.py — first run after a load hits cold caches).
|
||||
if have == 0 {
|
||||
tracing::debug!(model = %model.id, scenario = scenario.id(), "warmup run");
|
||||
let _ = scenario.run(&ctx).await;
|
||||
}
|
||||
|
||||
for i in 0..need {
|
||||
match scenario.run(&ctx).await {
|
||||
Ok(m) => {
|
||||
let rec = self.build_record(
|
||||
target,
|
||||
&build,
|
||||
discovery.as_ref(),
|
||||
model,
|
||||
scenario.id(),
|
||||
scenario.prompt_size(),
|
||||
Ok(&m),
|
||||
);
|
||||
self.store.insert_run(&rec)?;
|
||||
summary.measured += 1;
|
||||
tracing::info!(
|
||||
target = %target.name, model = %model.id, scenario = scenario.id(),
|
||||
ttft_s = m.ttft_s, decode_tps = ?m.decode_tps, total_s = m.total_s,
|
||||
"{}/{} recorded", have + i + 1, self.cfg.bench.samples_per_version
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = format!("{e:#}");
|
||||
let rec = self.build_record(
|
||||
target,
|
||||
&build,
|
||||
discovery.as_ref(),
|
||||
model,
|
||||
scenario.id(),
|
||||
scenario.prompt_size(),
|
||||
Err(&msg),
|
||||
);
|
||||
self.store.insert_run(&rec)?;
|
||||
summary.failed += 1;
|
||||
tracing::warn!(
|
||||
target = %target.name, model = %model.id, scenario = scenario.id(),
|
||||
error = %msg, "iteration failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(self.cfg.bench.iteration_pause()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_record(
|
||||
&self,
|
||||
target: &TargetConfig,
|
||||
build: &BuildInfo,
|
||||
discovery: Option<&DiscoveryResponse>,
|
||||
model: &ModelInfo,
|
||||
scenario_id: &str,
|
||||
prompt_size: u32,
|
||||
result: Result<&crate::scenario::ScenarioMetrics, &str>,
|
||||
) -> RunRecord {
|
||||
let (ok, error, ttft, decode, total, prompt_tokens, completion) = match result {
|
||||
Ok(m) => (
|
||||
true,
|
||||
None,
|
||||
Some(m.ttft_s),
|
||||
m.decode_tps,
|
||||
Some(m.total_s),
|
||||
m.prompt_tokens,
|
||||
Some(m.completion_tokens),
|
||||
),
|
||||
Err(e) => (false, Some(e.to_string()), None, None, None, None, None),
|
||||
};
|
||||
|
||||
RunRecord {
|
||||
ts: chrono::Utc::now().to_rfc3339(),
|
||||
target_name: target.name.clone(),
|
||||
target_kind: kind_str(target.kind).to_string(),
|
||||
endpoint: target.endpoint.clone(),
|
||||
hostname: discovery.map(|d| d.hostname.clone()),
|
||||
driver_version: discovery.and_then(|d| d.driver_version.clone()),
|
||||
cuda_version: discovery.and_then(|d| d.cuda_version.clone()),
|
||||
gpus_json: discovery
|
||||
.map(|d| serde_json::to_string(&d.devices).unwrap_or_else(|_| "[]".to_string())),
|
||||
git_sha: build.git_sha.clone(),
|
||||
git_sha_long: build.git_sha_long.clone(),
|
||||
package_version: build.package_version.clone(),
|
||||
git_dirty: build.git_dirty,
|
||||
build_timestamp: build.build_timestamp.clone(),
|
||||
rustc_version: build.rustc_version.clone(),
|
||||
profile: build.profile.clone(),
|
||||
features_json: serde_json::to_string(&build.features)
|
||||
.unwrap_or_else(|_| "[]".to_string()),
|
||||
candle_version: build.candle_version.clone(),
|
||||
bench_version: bench_version(),
|
||||
bench_sha: bench_sha(),
|
||||
model_id: model.id.clone(),
|
||||
harness: model.harness.clone(),
|
||||
capabilities_json: serde_json::to_string(&model.capabilities)
|
||||
.unwrap_or_else(|_| "[]".to_string()),
|
||||
devices_json: serde_json::to_string(&model.devices)
|
||||
.unwrap_or_else(|_| "[]".to_string()),
|
||||
scenario_id: scenario_id.to_string(),
|
||||
prompt_size_approx: prompt_size,
|
||||
prompt_tokens_actual: prompt_tokens,
|
||||
max_tokens: self.cfg.scenarios.max_tokens,
|
||||
ttft_s: ttft,
|
||||
decode_tps: decode,
|
||||
total_s: total,
|
||||
completion_tokens: completion,
|
||||
ok,
|
||||
error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn kind_str(kind: TargetKind) -> &'static str {
|
||||
match kind {
|
||||
TargetKind::Neuron => "neuron",
|
||||
TargetKind::Openai => "openai",
|
||||
}
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
//! End-to-end sweep against a mock neuron: a sweep records samples, a
|
||||
//! second sweep skips the satisfied cell, and bumping the reported build
|
||||
//! SHA resumes fresh sampling.
|
||||
|
||||
use axum::Router;
|
||||
use axum::extract::State;
|
||||
use axum::http::header;
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::routing::{get, post};
|
||||
use helexa_bench::config::{BenchConfig, BenchSettings, ScenarioConfig, TargetConfig, TargetKind};
|
||||
use helexa_bench::sweep::Sweeper;
|
||||
use serde_json::json;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Shared state of the mock neuron server.
#[derive(Clone)]
struct MockState {
    /// Build SHA reported by `/version`; the test mutates it to simulate a
    /// new build being deployed.
    sha: Arc<Mutex<String>>,
}
|
||||
|
||||
async fn version(State(s): State<MockState>) -> Json<serde_json::Value> {
|
||||
let sha = s.sha.lock().unwrap().clone();
|
||||
Json(json!({
|
||||
"package_version": "0.1.16",
|
||||
"git_sha": sha,
|
||||
"git_dirty": false,
|
||||
"features": ["cuda", "cudnn"],
|
||||
"candle_version": "0.10.2",
|
||||
}))
|
||||
}
|
||||
|
||||
async fn discovery() -> Json<serde_json::Value> {
|
||||
Json(json!({
|
||||
"hostname": "mock-beast",
|
||||
"os": "Linux",
|
||||
"kernel": "6.19.0",
|
||||
"cuda_version": "13.0",
|
||||
"driver_version": "580.159",
|
||||
"devices": [{"index": 0, "name": "RTX 5090", "vram_total_mb": 32614, "compute_capability": "12.0"}],
|
||||
"harnesses": ["candle"],
|
||||
}))
|
||||
}
|
||||
|
||||
async fn models() -> Json<serde_json::Value> {
|
||||
Json(json!([
|
||||
{"id": "Qwen/Qwen3.6-27B", "harness": "candle", "status": "loaded", "devices": [0], "capabilities": ["text"]},
|
||||
// A non-warm model the bench must ignore.
|
||||
{"id": "Qwen/cold", "harness": "candle", "status": "recovering", "devices": [0]},
|
||||
]))
|
||||
}
|
||||
|
||||
async fn chat() -> impl IntoResponse {
|
||||
let body = concat!(
|
||||
"data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n",
|
||||
"data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n",
|
||||
"data: {\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":130,\"completion_tokens\":2,\"total_tokens\":132}}\n\n",
|
||||
"data: [DONE]\n\n",
|
||||
);
|
||||
([(header::CONTENT_TYPE, "text/event-stream")], body)
|
||||
}
|
||||
|
||||
async fn spawn_mock(sha: &str) -> (String, Arc<Mutex<String>>) {
|
||||
let shared = Arc::new(Mutex::new(sha.to_string()));
|
||||
let state = MockState {
|
||||
sha: shared.clone(),
|
||||
};
|
||||
let app = Router::new()
|
||||
.route("/version", get(version))
|
||||
.route("/discovery", get(discovery))
|
||||
.route("/models", get(models))
|
||||
.route("/v1/chat/completions", post(chat))
|
||||
.with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
(format!("http://{addr}"), shared)
|
||||
}
|
||||
|
||||
fn config_for(endpoint: String, db_path: String) -> BenchConfig {
|
||||
BenchConfig {
|
||||
bench: BenchSettings {
|
||||
sweep_interval_secs: 1,
|
||||
samples_per_version: 2,
|
||||
iteration_pause_secs: 0,
|
||||
request_timeout_secs: 30,
|
||||
db_path,
|
||||
},
|
||||
scenarios: ScenarioConfig {
|
||||
prompt_sizes: vec![128], // single scenario keeps assertions simple
|
||||
max_tokens: 16,
|
||||
},
|
||||
targets: vec![TargetConfig {
|
||||
name: "mock".into(),
|
||||
kind: TargetKind::Neuron,
|
||||
endpoint,
|
||||
label: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sweep_records_skips_and_resumes_on_new_sha() {
|
||||
let (endpoint, sha_handle) = spawn_mock("aaaaaaa").await;
|
||||
|
||||
// Unique db path per run (bound port is unique).
|
||||
let port = endpoint.rsplit(':').next().unwrap();
|
||||
let db_path = std::env::temp_dir().join(format!("helexa-bench-it-{port}.sqlite"));
|
||||
let _ = std::fs::remove_file(&db_path);
|
||||
let db_str = db_path.to_string_lossy().to_string();
|
||||
|
||||
let sweeper = Sweeper::new(config_for(endpoint, db_str)).unwrap();
|
||||
|
||||
// First sweep: one warm model × one scenario × 2 samples.
|
||||
let s1 = sweeper.run_once().await.unwrap();
|
||||
assert_eq!(s1.measured, 2, "should record samples_per_version samples");
|
||||
assert_eq!(s1.skipped, 0);
|
||||
assert_eq!(s1.failed, 0);
|
||||
|
||||
// Second sweep at same SHA: cell satisfied, nothing measured.
|
||||
let s2 = sweeper.run_once().await.unwrap();
|
||||
assert_eq!(s2.measured, 0, "satisfied cell must be skipped");
|
||||
assert_eq!(s2.skipped, 1);
|
||||
|
||||
// Bump the reported build SHA: a new cell → fresh sampling resumes.
|
||||
*sha_handle.lock().unwrap() = "bbbbbbb".to_string();
|
||||
let s3 = sweeper.run_once().await.unwrap();
|
||||
assert_eq!(s3.measured, 2, "new SHA must resume sampling");
|
||||
assert_eq!(s3.skipped, 0);
|
||||
|
||||
let _ = std::fs::remove_file(&db_path);
|
||||
}
|
||||
@@ -12,36 +12,6 @@ path = "src/lib.rs"
|
||||
name = "neuron"
|
||||
path = "src/main.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# Enables CUDA acceleration in candle and the cudarc/nccl bindings the
|
||||
# TP worker pool uses. Without this feature, candle compiles for CPU
|
||||
# only, Device::new_cuda calls fall back to CPU, and TP Init/sanity
|
||||
# requests return Error{kind="cuda_feature_not_enabled"}.
|
||||
cuda = [
|
||||
"candle-core/cuda",
|
||||
"candle-core/nccl",
|
||||
"candle-nn/cuda",
|
||||
"candle-transformers/cuda",
|
||||
"dep:cudarc",
|
||||
"dep:half",
|
||||
"dep:cudaforge",
|
||||
]
|
||||
# Use cuDNN for convolution / attention kernels. Requires CUDA.
|
||||
cudnn = [
|
||||
"cuda",
|
||||
"candle-core/cudnn",
|
||||
"candle-nn/cudnn",
|
||||
"candle-transformers/cudnn",
|
||||
]
|
||||
# FlashAttention kernels. Requires CUDA.
|
||||
flash-attn = [
|
||||
"cuda",
|
||||
"candle-transformers/flash-attn",
|
||||
]
|
||||
# Reserved for GPU-only integration tests in later stages.
|
||||
cuda-integration = ["cuda"]
|
||||
|
||||
[dependencies]
|
||||
cortex-core.workspace = true
|
||||
tokio.workspace = true
|
||||
@@ -54,70 +24,9 @@ tracing-subscriber.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
clap.workspace = true
|
||||
thiserror.workspace = true
|
||||
futures.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
figment.workspace = true
|
||||
toml.workspace = true
|
||||
|
||||
# Parallel in-situ quantization (#1): fans candle's per-block k-quant
|
||||
# math across the CPU pool at model-load time. Already in the tree
|
||||
# transitively via candle-core.
|
||||
rayon = "1"
|
||||
|
||||
# candle for in-process inference. CUDA support is gated behind the
|
||||
# crate's `cuda` feature (default off) so the workspace builds on
|
||||
# non-CUDA hosts and CI runners.
|
||||
candle-core = "0.10.2"
|
||||
candle-nn = "0.10.2"
|
||||
candle-transformers = "0.10.2"
|
||||
# Direct dep on cudarc (matching candle's transitive version) so the
|
||||
# TP worker pool can call cudarc::nccl::{Comm, Id} directly. Gated on
|
||||
# the `cuda` feature; same toolchain requirement as candle's CUDA path.
|
||||
cudarc = { version = "0.19", optional = true, default-features = false, features = ["nccl", "cuda-version-from-build-system"] }
|
||||
# Used by the AllReduce CustomOp1 to type-dispatch on bf16/f16 candle
|
||||
# storages. Matches candle-core's pinned major version to avoid double-
|
||||
# compiling the `half` crate at conflicting versions.
|
||||
half = { version = "2.5", optional = true }
|
||||
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||
# Jinja-compatible template renderer for the model's chat template
|
||||
# (standalone `chat_template.jinja` or `tokenizer_config.json::chat_template`).
|
||||
# Hugging Face's chat templates lean on Python string semantics; we
|
||||
# bridge them with `minijinja-contrib`'s `pycompat` callback (str
|
||||
# methods like `startswith`/`split`/`strip`) plus a `raise_exception`
|
||||
# global. Features: `builtins` for `is defined` / `default`; `json`
|
||||
# for `tojson`; `serde` so we can hand it a serde_json::Value context.
|
||||
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
||||
# Python-compatibility shim: the Qwen3-VL / Qwen3.6 template uses
|
||||
# `content.startswith(...)`, `.endswith(...)`, `.split(...)`,
|
||||
# `.rstrip(...)`, `.lstrip(...)` — Python str methods minijinja doesn't
|
||||
# implement natively. `pycompat::unknown_method_callback` supplies them.
|
||||
minijinja-contrib = { version = "2", features = ["pycompat"] }
|
||||
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||
# without materialising the full tensor on device.
|
||||
safetensors = "0.7"
|
||||
# Vision capability for Qwen3.6 (Stage A of the vision plan in
|
||||
# doc/vision-qwen3_6-spec.md). `image` decodes PNG/JPEG/etc from
|
||||
# the bytes embedded in `data:image/...;base64,...` content parts;
|
||||
# `base64` does the URI decode. Default-features off on `image` to
|
||||
# avoid pulling in audio/video formats we don't need.
|
||||
image = { version = "0.25", default-features = false, features = ["png", "jpeg", "webp", "bmp", "gif"] }
|
||||
base64 = "0.22"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
reqwest.workspace = true
|
||||
tempfile = "3"
|
||||
|
||||
[build-dependencies]
|
||||
# Used by `build.rs` to compile `src/cuda/*.cu` into `libneuroncuda.a`
|
||||
# under the `cuda` feature. Matches mistralrs's upstream build setup
|
||||
# (their `mistralrs-core/build.rs` uses the same constructor).
|
||||
cudaforge = { version = "0.1", optional = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
# Skip the CUDA path on docs.rs (it lacks nvcc).
|
||||
no-default-features = true
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
//! Build script: capture build/version metadata for `GET /version`,
|
||||
//! and (under the `cuda` feature) compile the CUDA kernels in
|
||||
//! `src/cuda/*.cu` into a static library and link it.
|
||||
//!
|
||||
//! The CUDA portion is patterned on
|
||||
//! `EricLBuehler/mistral.rs::mistralrs-core/build.rs` — same
|
||||
//! `cudaforge::KernelBuilder` invocation, same NVCC flag set.
|
||||
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
emit_build_metadata();
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
use std::path::PathBuf;
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rerun-if-changed=src/cuda/");
|
||||
|
||||
let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
|
||||
|
||||
let mut builder = cudaforge::KernelBuilder::new()
|
||||
.source_glob("src/cuda/*.cu")
|
||||
.out_dir(&build_dir)
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--compiler-options")
|
||||
.arg("-fPIC");
|
||||
|
||||
// sm_<80 doesn't have bf16 intrinsics for WMMA — gate the
|
||||
// bf16-only kernels off in that case. (Mirrors upstream.)
|
||||
if let Some(compute_cap) = builder.get_compute_cap()
|
||||
&& compute_cap < 80
|
||||
{
|
||||
builder = builder.arg("-DNO_BF16_KERNEL");
|
||||
}
|
||||
|
||||
let target = std::env::var("TARGET").unwrap();
|
||||
let out_file = if target.contains("msvc") {
|
||||
build_dir.join("neuroncuda.lib")
|
||||
} else {
|
||||
build_dir.join("libneuroncuda.a")
|
||||
};
|
||||
|
||||
builder
|
||||
.build_lib(out_file)
|
||||
.expect("neuron cuda build failed");
|
||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||
println!("cargo:rustc-link-lib=neuroncuda");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
|
||||
if target.contains("msvc") {
|
||||
// No extra runtime library needed.
|
||||
} else if target.contains("apple")
|
||||
|| target.contains("freebsd")
|
||||
|| target.contains("openbsd")
|
||||
{
|
||||
println!("cargo:rustc-link-lib=dylib=c++");
|
||||
} else if target.contains("android") {
|
||||
println!("cargo:rustc-link-lib=dylib=c++_shared");
|
||||
} else {
|
||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit `cargo:rustc-env=` vars consumed by `env!()` in `src/version.rs`
|
||||
/// so the daemon can report its own build identity from `GET /version`.
|
||||
///
|
||||
/// We re-run only when HEAD moves or the SHA override changes — not on
|
||||
/// every compile — so the captured timestamp is stable for a given
|
||||
/// build input rather than churning on each `cargo build`.
|
||||
fn emit_build_metadata() {
|
||||
println!("cargo:rerun-if-env-changed=HELEXA_BUILD_SHA");
|
||||
println!("cargo:rerun-if-changed=.git/HEAD");
|
||||
// A detached/normal HEAD points at a ref whose file is what actually
|
||||
// changes on commit; watch the packed-refs fallback too.
|
||||
println!("cargo:rerun-if-changed=.git/packed-refs");
|
||||
|
||||
// SHA: prefer the CI/RPM-injected override (tarball builds have no
|
||||
// .git), then fall back to git, then to "unknown".
|
||||
let (sha_short, sha_long, dirty) = match std::env::var("HELEXA_BUILD_SHA") {
|
||||
Ok(s) if !s.trim().is_empty() => {
|
||||
let s = s.trim().to_string();
|
||||
let short = s.chars().take(7).collect::<String>();
|
||||
(short, Some(s), false)
|
||||
}
|
||||
_ => {
|
||||
let long = git(&["rev-parse", "HEAD"]);
|
||||
let short = git(&["rev-parse", "--short", "HEAD"]);
|
||||
let dirty = git(&["status", "--porcelain"])
|
||||
.map(|s| !s.trim().is_empty())
|
||||
.unwrap_or(false);
|
||||
match short {
|
||||
Some(short) => (short, long, dirty),
|
||||
None => ("unknown".to_string(), None, false),
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("cargo:rustc-env=HELEXA_GIT_SHA={sha_short}");
|
||||
println!(
|
||||
"cargo:rustc-env=HELEXA_GIT_SHA_LONG={}",
|
||||
sha_long.unwrap_or_default()
|
||||
);
|
||||
println!("cargo:rustc-env=HELEXA_GIT_DIRTY={dirty}");
|
||||
|
||||
// RFC3339 build timestamp. `date` is universally present on the
|
||||
// Linux hosts neuron targets; empty if it ever isn't.
|
||||
let ts = Command::new("date")
|
||||
.args(["-u", "+%Y-%m-%dT%H:%M:%SZ"])
|
||||
.output()
|
||||
.ok()
|
||||
.filter(|o| o.status.success())
|
||||
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||
.unwrap_or_default();
|
||||
println!("cargo:rustc-env=HELEXA_BUILD_TIMESTAMP={ts}");
|
||||
|
||||
// Compiler version: cargo sets $RUSTC to the rustc it invokes.
|
||||
let rustc = std::env::var("RUSTC").unwrap_or_else(|_| "rustc".to_string());
|
||||
let rustc_version = Command::new(rustc)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.ok()
|
||||
.filter(|o| o.status.success())
|
||||
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||
.unwrap_or_default();
|
||||
println!("cargo:rustc-env=HELEXA_RUSTC_VERSION={rustc_version}");
|
||||
|
||||
println!(
|
||||
"cargo:rustc-env=HELEXA_BUILD_PROFILE={}",
|
||||
std::env::var("PROFILE").unwrap_or_default()
|
||||
);
|
||||
println!(
|
||||
"cargo:rustc-env=HELEXA_TARGET={}",
|
||||
std::env::var("TARGET").unwrap_or_default()
|
||||
);
|
||||
|
||||
// Enabled features: cargo exports CARGO_FEATURE_<NAME> for each.
|
||||
// Reverse the mangling (uppercase, '-'→'_') best-effort for display.
|
||||
let mut features: Vec<String> = std::env::vars()
|
||||
.filter_map(|(k, _)| k.strip_prefix("CARGO_FEATURE_").map(|f| f.to_string()))
|
||||
.map(|f| f.to_lowercase().replace('_', "-"))
|
||||
// `default` is the meta-feature, not a perf-relevant flag.
|
||||
.filter(|f| f != "default")
|
||||
.collect();
|
||||
features.sort();
|
||||
println!("cargo:rustc-env=HELEXA_FEATURES={}", features.join(","));
|
||||
|
||||
println!(
|
||||
"cargo:rustc-env=HELEXA_CANDLE_VERSION={}",
|
||||
candle_version().unwrap_or_default()
|
||||
);
|
||||
}
|
||||
|
||||
/// Run `git` with `args` and return its trimmed standard output.
///
/// Yields `None` when the process cannot be spawned, exits with a
/// non-zero status, or prints nothing but whitespace.
fn git(args: &[&str]) -> Option<String> {
    Command::new("git")
        .args(args)
        .output()
        .ok()
        .filter(|run| run.status.success())
        .map(|run| String::from_utf8_lossy(&run.stdout).trim().to_string())
        .filter(|text| !text.is_empty())
}
|
||||
|
||||
/// Best-effort lookup of the locked `candle-core` version.
///
/// Reads `Cargo.lock` two directories above `CARGO_MANIFEST_DIR` and
/// returns the `version` recorded for the `candle-core` package entry.
/// Also registers the lockfile with cargo via `rerun-if-changed`.
/// Returns `None` if the environment variable is unset, the lockfile
/// cannot be read (e.g. some packaging flows), or no matching entry is
/// found.
fn candle_version() -> Option<String> {
    let crate_dir = std::env::var("CARGO_MANIFEST_DIR").ok()?;
    let lock_path = std::path::Path::new(&crate_dir)
        .join("..")
        .join("..")
        .join("Cargo.lock");
    println!("cargo:rerun-if-changed={}", lock_path.display());
    let contents = std::fs::read_to_string(lock_path).ok()?;

    // Cargo.lock entries are `[[package]]\nname = "x"\nversion = "y"`.
    // Track whether the entry currently being scanned is candle-core.
    let mut inside_candle = false;
    for raw in contents.lines() {
        match raw.trim() {
            "[[package]]" => inside_candle = false,
            "name = \"candle-core\"" => inside_candle = true,
            entry if inside_candle => {
                if let Some(tail) = entry.strip_prefix("version = \"") {
                    return Some(tail.trim_end_matches('"').to_string());
                }
            }
            _ => {}
        }
    }
    None
}
|
||||
@@ -1,93 +0,0 @@
|
||||
//! Activation-time pre-warm progress tracking.
|
||||
//!
|
||||
//! Wraps the [`ActivationStatus`] snapshot in an async RwLock so the
|
||||
//! background pre-warm task can update it per-model while the
|
||||
//! `/health` handler reads coherent snapshots. The tracker exists
|
||||
//! because `default_models` loading moved from synchronous-before-bind
|
||||
//! to background-after-bind on 2026-05-26: the listener is up
|
||||
//! immediately, but `/health` now needs to tell callers which of the
|
||||
//! configured defaults are still warming.
|
||||
|
||||
use cortex_core::discovery::{ActivationState, ActivationStatus, PreWarmFailure};
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Shared, async-safe handle to the daemon's activation progress.
|
||||
///
|
||||
/// Construct once in `main` with the configured `default_models` so
|
||||
/// the initial `pending` list matches the spec; clone the `Arc` into
|
||||
/// the `NeuronState` for HTTP handlers and into the spawned pre-warm
|
||||
/// task for updates.
|
||||
pub struct ActivationTracker {
|
||||
inner: RwLock<ActivationStatus>,
|
||||
}
|
||||
|
||||
impl ActivationTracker {
|
||||
/// Build a tracker primed with one entry per spec. An empty spec
|
||||
/// list yields a `Ready` tracker — no point reporting PreWarming
|
||||
/// when there's nothing queued.
|
||||
pub fn new(default_models: &[ModelSpec]) -> Self {
|
||||
let pending: Vec<String> = default_models.iter().map(|s| s.model_id.clone()).collect();
|
||||
let state = if pending.is_empty() {
|
||||
ActivationState::Ready
|
||||
} else {
|
||||
ActivationState::PreWarming
|
||||
};
|
||||
Self {
|
||||
inner: RwLock::new(ActivationStatus {
|
||||
state,
|
||||
pending,
|
||||
in_progress: None,
|
||||
completed: vec![],
|
||||
failed: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a model as in-progress: remove it from `pending`, set as
|
||||
/// `in_progress`. Called immediately before `registry.load_model`.
|
||||
pub async fn start_loading(&self, model_id: &str) {
|
||||
let mut s = self.inner.write().await;
|
||||
s.pending.retain(|m| m != model_id);
|
||||
s.in_progress = Some(model_id.to_string());
|
||||
}
|
||||
|
||||
/// Mark a model as completed: clear `in_progress` (if it matches),
|
||||
/// append to `completed`.
|
||||
pub async fn complete_loading(&self, model_id: &str) {
|
||||
let mut s = self.inner.write().await;
|
||||
if s.in_progress.as_deref() == Some(model_id) {
|
||||
s.in_progress = None;
|
||||
}
|
||||
s.completed.push(model_id.to_string());
|
||||
}
|
||||
|
||||
/// Mark a model as failed: clear `in_progress` (if it matches),
|
||||
/// append a `PreWarmFailure` carrying the rendered error chain.
|
||||
pub async fn fail_loading(&self, model_id: &str, error: &str) {
|
||||
let mut s = self.inner.write().await;
|
||||
if s.in_progress.as_deref() == Some(model_id) {
|
||||
s.in_progress = None;
|
||||
}
|
||||
s.failed.push(PreWarmFailure {
|
||||
model_id: model_id.to_string(),
|
||||
error: error.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Flip the high-level `state` to `Ready` once the pre-warm task
|
||||
/// is done iterating. Pending should be empty by this point; if a
|
||||
/// caller bails early it's a stuck activation and the operator
|
||||
/// will see entries in `pending` even with `state=ready` — that's
|
||||
/// a useful diagnostic, not an inconsistency to scrub.
|
||||
pub async fn mark_ready(&self) {
|
||||
let mut s = self.inner.write().await;
|
||||
s.state = ActivationState::Ready;
|
||||
s.in_progress = None;
|
||||
}
|
||||
|
||||
/// Cheap clone of the current state for the `/health` handler.
|
||||
pub async fn snapshot(&self) -> ActivationStatus {
|
||||
self.inner.read().await.clone()
|
||||
}
|
||||
}
|
||||
@@ -1,63 +1,34 @@
|
||||
//! HTTP API handlers for the neuron daemon.
|
||||
|
||||
use crate::activation::ActivationTracker;
|
||||
use crate::harness::HarnessRegistry;
|
||||
use crate::harness::candle::{CandleHarness, InferenceError};
|
||||
use crate::harness::preflight::PreflightError;
|
||||
use crate::health::HealthCache;
|
||||
use crate::wire::{openai_chat, openai_responses};
|
||||
use axum::Router;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Json};
|
||||
use axum::routing::{get, post};
|
||||
use cortex_core::discovery::{DiscoveryResponse, HealthResponse};
|
||||
use cortex_core::harness::ModelSpec;
|
||||
use cortex_core::openai::{ChatCompletionRequest, MessageContent};
|
||||
use cortex_core::responses::{ResponsesRequest, ResponsesUsage};
|
||||
use futures::stream::{self, StreamExt};
|
||||
use serde_json::{Value, json};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
/// Shared state for the neuron HTTP server.
|
||||
pub struct NeuronState {
|
||||
pub discovery: DiscoveryResponse,
|
||||
pub health_cache: Arc<HealthCache>,
|
||||
pub registry: RwLock<HarnessRegistry>,
|
||||
/// Typed handle to the candle harness for inference routes. Cached at
|
||||
/// startup so `/v1/chat/completions` doesn't have to hold the registry
|
||||
/// read lock or perform dyn-Trait dispatch per request.
|
||||
pub candle: Option<Arc<CandleHarness>>,
|
||||
/// Activation-time pre-warm progress. Updated by the background
|
||||
/// `load_default_models` task, read by the `/health` handler.
|
||||
pub activation: Arc<ActivationTracker>,
|
||||
}
|
||||
|
||||
/// Build the neuron API router.
|
||||
pub fn neuron_routes() -> Router<Arc<NeuronState>> {
|
||||
Router::new()
|
||||
.route("/version", get(version_handler))
|
||||
.route("/discovery", get(discovery_handler))
|
||||
.route("/health", get(health_handler))
|
||||
.route("/models", get(list_models))
|
||||
.route("/models/load", post(load_model))
|
||||
.route("/models/unload", post(unload_model))
|
||||
.route("/models/{model_id}/endpoint", get(model_endpoint))
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/responses", post(responses))
|
||||
}
|
||||
|
||||
/// `GET /version` — the daemon's own build identity (git SHA, enabled
|
||||
/// features, rustc/candle versions). Static for the process lifetime, so
|
||||
/// no state is touched. This is the canonical "which build is live"
|
||||
/// probe for fleet validation and benchmark attribution.
|
||||
async fn version_handler() -> Json<cortex_core::build_info::BuildInfo> {
|
||||
Json(crate::version::build_info())
|
||||
}
|
||||
|
||||
async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<DiscoveryResponse> {
|
||||
@@ -65,13 +36,7 @@ async fn discovery_handler(State(state): State<Arc<NeuronState>>) -> Json<Discov
|
||||
}
|
||||
|
||||
async fn health_handler(State(state): State<Arc<NeuronState>>) -> Json<HealthResponse> {
|
||||
// HealthCache owns the uptime + per-device readings; the activation
|
||||
// tracker owns the pre-warm progress. We compose the response here
|
||||
// so the cache stays a thin runtime-state cache and doesn't need to
|
||||
// know about activation lifecycle.
|
||||
let mut snapshot = state.health_cache.snapshot().await;
|
||||
snapshot.activation = state.activation.snapshot().await;
|
||||
Json(snapshot)
|
||||
Json(state.health_cache.snapshot().await)
|
||||
}
|
||||
|
||||
async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse {
|
||||
@@ -80,7 +45,7 @@ async fn list_models(State(state): State<Arc<NeuronState>>) -> impl IntoResponse
|
||||
Ok(models) => Json(json!(models)).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -90,70 +55,14 @@ async fn load_model(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
Json(spec): Json<ModelSpec>,
|
||||
) -> impl IntoResponse {
|
||||
// Driver/library mismatch preflight (#19): every CUDA load is
|
||||
// guaranteed to fail until the host reboots. Reject up front with
|
||||
// the operator-actionable reason instead of letting the load die
|
||||
// minutes later inside cuInit/NCCL with a cryptic error.
|
||||
if let Some(reason) = &state.discovery.cuda_unavailable_reason {
|
||||
tracing::warn!(model = %spec.model_id, reason = %reason, "load_model rejected: CUDA unavailable");
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": reason,
|
||||
"code": "cuda_unavailable",
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let registry = state.registry.read().await;
|
||||
match registry.load_model(&spec).await {
|
||||
Ok(()) => Json(json!({"status": "loaded"})).into_response(),
|
||||
Err(e) => {
|
||||
// If the underlying failure is a structured preflight
|
||||
// rejection, surface it as 422 Unprocessable Entity with
|
||||
// the typed JSON body. The kind/model_id/suggestion/etc.
|
||||
// fields let cortex (and operators reading the response
|
||||
// directly) act on the failure without parsing free text.
|
||||
if let Some(pf) = e.downcast_ref::<PreflightError>() {
|
||||
tracing::warn!(
|
||||
model = %spec.model_id,
|
||||
reason = preflight_kind(pf),
|
||||
detail = %pf,
|
||||
"load_model rejected by preflight"
|
||||
);
|
||||
return (
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(json!({ "error": pf })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
// Log the full anyhow chain server-side so journalctl shows
|
||||
// the underlying failure (hf-hub timeout, permission denied,
|
||||
// disk full, etc.) without needing to inspect the HTTP
|
||||
// response body separately.
|
||||
tracing::warn!(
|
||||
model = %spec.model_id,
|
||||
error = %format!("{e:#}"),
|
||||
"load_model failed"
|
||||
);
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Short kebab-case tag for a preflight failure, used as a structured
|
||||
/// log field for journalctl-side filtering. Mirrors the same helper in
|
||||
/// `startup.rs`; duplicated to keep the module surfaces independent.
|
||||
fn preflight_kind(err: &PreflightError) -> &'static str {
|
||||
match err {
|
||||
PreflightError::RepoFetchFailed { .. } => "repo_fetch_failed",
|
||||
PreflightError::EmptyRepo { .. } => "empty_repo",
|
||||
PreflightError::TpRequiresSafetensors { .. } => "tp_requires_safetensors",
|
||||
PreflightError::QuantNotFound { .. } => "quant_not_found",
|
||||
Err(e) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": e.to_string()})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,11 +84,7 @@ async fn unload_model(
|
||||
let registry = state.registry.read().await;
|
||||
match registry.unload_model(&model_id).await {
|
||||
Ok(()) => Json(json!({"status": "unloaded"})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(e) => (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()}))).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,347 +102,3 @@ async fn model_endpoint(
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenAI-compatible chat completions. Dispatches to streaming SSE when
|
||||
/// `stream: true` is set on the request; otherwise returns a single
|
||||
/// `ChatCompletionResponse`.
|
||||
async fn chat_completions(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(req): Json<ChatCompletionRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
// Reasoning-content opt-in. Off by default → naïve clients
|
||||
// (Zed's commit-message generator, vanilla OpenAI clients)
|
||||
// never see `<think>` blocks. On when the caller sends
|
||||
// `x-include-thinking: true` (helexa-acp does this so its
|
||||
// own ThinkParser keeps working unchanged).
|
||||
let include_thinking = headers
|
||||
.get("x-include-thinking")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| matches!(s.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
let chat_config = openai_chat::ChatProjectionConfig {
|
||||
include_thinking,
|
||||
reasoning_markers: None, // filled in from the loaded model inside candle
|
||||
};
|
||||
|
||||
if req.stream.unwrap_or(false) {
|
||||
match candle.chat_completion_stream_with(req, chat_config).await {
|
||||
Ok(rx) => {
|
||||
// Each chunk → one SSE `data: {json}` line. After the
|
||||
// channel closes, append the OpenAI [DONE] terminator.
|
||||
let body_stream = ReceiverStream::new(rx).map(|chunk| {
|
||||
let body = serde_json::to_string(&chunk).unwrap_or_default();
|
||||
Ok::<_, Infallible>(Event::default().data(body))
|
||||
});
|
||||
let done_stream =
|
||||
stream::once(async { Ok::<_, Infallible>(Event::default().data("[DONE]")) });
|
||||
Sse::new(body_stream.chain(done_stream))
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||
"code": "prompt_too_long",
|
||||
"prompt_len": prompt_len,
|
||||
"max": max,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::InsufficientVram {
|
||||
free_mb,
|
||||
required_mb,
|
||||
}) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
),
|
||||
"code": "insufficient_vram",
|
||||
"free_mb": free_mb,
|
||||
"required_mb": required_mb,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::VisionUnsupported { model_id }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"model '{model_id}' does not support image input"
|
||||
),
|
||||
"code": "vision_unsupported",
|
||||
"model_id": model_id,
|
||||
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
} else {
|
||||
match candle.chat_completion(req).await {
|
||||
Ok(resp) => Json(resp).into_response(),
|
||||
Err(InferenceError::ModelNotLoaded(id)) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::PromptTooLong { prompt_len, max }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||
"code": "prompt_too_long",
|
||||
"prompt_len": prompt_len,
|
||||
"max": max,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::InsufficientVram {
|
||||
free_mb,
|
||||
required_mb,
|
||||
}) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
),
|
||||
"code": "insufficient_vram",
|
||||
"free_mb": free_mb,
|
||||
"required_mb": required_mb,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::VisionUnsupported { model_id }) => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"model '{model_id}' does not support image input"
|
||||
),
|
||||
"code": "vision_unsupported",
|
||||
"model_id": model_id,
|
||||
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(InferenceError::Other(e)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenAI Responses API (`POST /v1/responses`). Translates the
|
||||
/// Responses-shaped request into a chat-completions one the candle
|
||||
/// harness already understands, then re-projects the harness's
|
||||
/// event stream into the Responses event family.
|
||||
async fn responses(
|
||||
State(state): State<Arc<NeuronState>>,
|
||||
Json(req): Json<ResponsesRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let Some(candle) = state.candle.as_ref().map(Arc::clone) else {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({"error": "candle harness not enabled on this neuron"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let stream_requested = req.stream;
|
||||
let model_id = req.model.clone();
|
||||
let response_id = mint_response_id();
|
||||
let message_item_id = mint_message_item_id();
|
||||
|
||||
// Translate Responses → chat completions. The only failure
|
||||
// mode today is `previous_response_id` set, which we reject
|
||||
// with 400 — stateful conversations need a persistence layer
|
||||
// we haven't built.
|
||||
let mut chat_req = match openai_responses::request_to_chat(req) {
|
||||
Ok(r) => r,
|
||||
Err(openai_responses::TranslateError::ChainedConversationNotSupported) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": "previous_response_id is not supported on this neuron",
|
||||
"code": "chained_conversation_not_supported"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
chat_req.stream = Some(stream_requested);
|
||||
|
||||
if stream_requested {
|
||||
match candle
|
||||
.responses_stream(chat_req, response_id, message_item_id)
|
||||
.await
|
||||
{
|
||||
Ok(rx) => {
|
||||
// Each ResponseStreamFrame → one SSE event carrying
|
||||
// both an event name and JSON data. The Responses
|
||||
// API doesn't use a `[DONE]` terminator — clients
|
||||
// see the `response.completed` event as the end of
|
||||
// the stream.
|
||||
let body_stream = ReceiverStream::new(rx).map(|frame| {
|
||||
let body = serde_json::to_string(&frame.data).unwrap_or_else(|_| "{}".into());
|
||||
Ok::<_, Infallible>(Event::default().event(frame.event_name).data(body))
|
||||
});
|
||||
Sse::new(body_stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => inference_error_response(e),
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: drive the existing chat completion path
|
||||
// and translate the result. We don't currently re-tokenise
|
||||
// to compute usage; the harness returns it via the chat
|
||||
// response and we pass it through.
|
||||
match candle.chat_completion(chat_req).await {
|
||||
Ok(chat_resp) => {
|
||||
// Extract the assistant text (chat completions
|
||||
// always emits one choice on the candle path).
|
||||
let text = chat_resp
|
||||
.choices
|
||||
.first()
|
||||
.map(|c| match &c.message.content {
|
||||
MessageContent::Text(t) => t.clone(),
|
||||
MessageContent::Parts(_) => {
|
||||
// Candle output is always text today;
|
||||
// a Parts response would be surprising.
|
||||
// Empty-string fallback is safer than
|
||||
// a panic.
|
||||
String::new()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let finish = chat_resp
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|c| c.finish_reason.as_deref())
|
||||
.map(finish_reason_from_str)
|
||||
.unwrap_or(crate::wire::FinishReason::Stop);
|
||||
let usage = chat_resp.usage.as_ref().map(|u| ResponsesUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
total_tokens: u.prompt_tokens + u.completion_tokens,
|
||||
});
|
||||
let meta = openai_responses::ResponseMeta {
|
||||
response_id: mint_response_id(),
|
||||
created_at: unix_now_secs(),
|
||||
model_id,
|
||||
message_item_id: mint_message_item_id(),
|
||||
};
|
||||
let _ = chat_resp; // make the borrow-checker happy if `text` consumed it
|
||||
let resp = openai_responses::build_response(&meta, text, finish, usage);
|
||||
Json(resp).into_response()
|
||||
}
|
||||
Err(e) => inference_error_response(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn finish_reason_from_str(s: &str) -> crate::wire::FinishReason {
|
||||
use crate::wire::FinishReason;
|
||||
match s {
|
||||
"length" => FinishReason::Length,
|
||||
"tool_calls" => FinishReason::ToolCalls,
|
||||
_ => FinishReason::Stop,
|
||||
}
|
||||
}
|
||||
|
||||
/// Centralised mapping from [`InferenceError`] to an HTTP response.
|
||||
/// Lifted out so the chat-completions and responses handlers stay
|
||||
/// readable and changes to error-code semantics happen in one spot.
|
||||
fn inference_error_response(err: InferenceError) -> axum::response::Response {
|
||||
match err {
|
||||
InferenceError::ModelNotLoaded(id) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("model '{id}' not loaded on this neuron")})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::PromptTooLong { prompt_len, max } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!("prompt has {prompt_len} tokens but max is {max}"),
|
||||
"code": "prompt_too_long",
|
||||
"prompt_len": prompt_len,
|
||||
"max": max,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::InsufficientVram {
|
||||
free_mb,
|
||||
required_mb,
|
||||
} => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"insufficient free VRAM: {free_mb} MiB free, need at least {required_mb} MiB"
|
||||
),
|
||||
"code": "insufficient_vram",
|
||||
"free_mb": free_mb,
|
||||
"required_mb": required_mb,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::VisionUnsupported { model_id } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"model '{model_id}' does not support image input"
|
||||
),
|
||||
"code": "vision_unsupported",
|
||||
"model_id": model_id,
|
||||
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
InferenceError::Other(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("{e:#}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mint_response_id() -> String {
|
||||
format!("resp_{:x}", unix_subsec_nanos())
|
||||
}
|
||||
|
||||
fn mint_message_item_id() -> String {
|
||||
format!("msg_{:x}", unix_subsec_nanos())
|
||||
}
|
||||
|
||||
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 when the system clock is set before the epoch, so callers
/// never have to handle a clock error.
fn unix_now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
|
||||
|
||||
/// Current wall-clock time as nanoseconds since the Unix epoch.
///
/// Despite the name this is NOT the sub-second component: it is the
/// whole duration since the epoch expressed in nanoseconds. The name is
/// kept because callers elsewhere refer to it.
///
/// Returns 0 when the system clock is set before the epoch. A reading
/// too large for `u64` saturates at `u64::MAX` instead of silently
/// wrapping, which is what the previous `as u64` cast did.
fn unix_subsec_nanos() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX)
        })
}
|
||||
|
||||
@@ -1,22 +1,12 @@
|
||||
//! Neuron configuration loaded from neuron.toml.
|
||||
|
||||
use cortex_core::harness::{HarnessConfig, ModelSpec};
|
||||
use cortex_core::harness::HarnessConfig;
|
||||
use figment::{
|
||||
Figment,
|
||||
providers::{Env, Format, Toml},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Default scheme name applied to bare `org/name` model ids when no
|
||||
/// `[harness.candle.default_source]` is set. Keeps existing operator
|
||||
/// configs (which know nothing about schemes) working unchanged.
|
||||
pub const DEFAULT_SOURCE_SCHEME: &str = "huggingface";
|
||||
|
||||
/// Endpoint URL for the default huggingface source, used when no
|
||||
/// `[harness.candle.sources.huggingface]` is configured.
|
||||
pub const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NeuronConfig {
|
||||
@@ -24,156 +14,6 @@ pub struct NeuronConfig {
|
||||
pub port: u16,
|
||||
#[serde(default)]
|
||||
pub harnesses: Vec<HarnessConfig>,
|
||||
/// Per-harness configuration. Currently only `candle` is recognised.
|
||||
#[serde(default)]
|
||||
pub harness: HarnessSettings,
|
||||
/// Models to auto-load when the neuron service activates. Each entry
|
||||
/// is loaded sequentially before the HTTP listener binds. A failure
|
||||
/// on any single entry logs a warning and proceeds — broken entries
|
||||
/// don't prevent the rest of the fleet from starting.
|
||||
#[serde(default)]
|
||||
pub default_models: Vec<ModelSpec>,
|
||||
}
|
||||
|
||||
/// Settings for individual harness implementations. Each harness owns
|
||||
/// its own sub-table so users only configure the harnesses they enable.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct HarnessSettings {
|
||||
#[serde(default)]
|
||||
pub candle: CandleHarnessConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct CandleHarnessConfig {
|
||||
/// HuggingFace cache directory for model weights.
|
||||
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
||||
///
|
||||
/// Retained for back-compat — operators with existing
|
||||
/// `hf_cache = "..."` configs continue to work. Treated as the
|
||||
/// `huggingface` source's cache_dir when a sources table isn't
|
||||
/// provided.
|
||||
#[serde(default)]
|
||||
pub hf_cache: Option<PathBuf>,
|
||||
|
||||
/// Default source scheme applied to bare `org/name` model ids
|
||||
/// (those without an explicit `scheme:` prefix). When unset, falls
|
||||
/// back to `DEFAULT_SOURCE_SCHEME` ("huggingface").
|
||||
#[serde(default)]
|
||||
pub default_source: Option<String>,
|
||||
|
||||
/// Per-scheme source endpoints. Each entry maps a scheme name
|
||||
/// (`huggingface`, `helexa`, an operator's mirror tag, …) to its
|
||||
/// endpoint URL, optional auth env var, and optional cache
|
||||
/// directory.
|
||||
///
|
||||
/// When absent or missing the `huggingface` key, the loader
|
||||
/// synthesises a `huggingface` entry pointing at
|
||||
/// `https://huggingface.co` with `hf_cache` (above) as its
|
||||
/// cache_dir. This keeps single-source configs ergonomic.
|
||||
#[serde(default)]
|
||||
pub sources: HashMap<String, SourceConfig>,
|
||||
|
||||
/// Prefix KV cache across requests (#11). Applies per loaded
|
||||
/// model, on architectures that support cache snapshots (qwen3_5).
|
||||
#[serde(default)]
|
||||
pub prefix_cache: PrefixCacheConfig,
|
||||
}
|
||||
|
||||
/// `[harness.candle.prefix_cache]` settings.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PrefixCacheConfig {
|
||||
/// Master switch. On by default — set `false` to restore the
|
||||
/// clear-every-request behaviour.
|
||||
#[serde(default = "default_prefix_cache_enabled")]
|
||||
pub enabled: bool,
|
||||
/// Snapshot byte budget per loaded model, in MiB. Snapshots live
|
||||
/// on the model's device, so this comes out of the same VRAM that
|
||||
/// serves inference — size it against the device's headroom after
|
||||
/// the model weights.
|
||||
#[serde(default = "default_prefix_cache_budget_mb")]
|
||||
pub budget_mb: u64,
|
||||
/// Maximum live snapshots per loaded model, regardless of budget.
|
||||
#[serde(default = "default_prefix_cache_max_entries")]
|
||||
pub max_entries: usize,
|
||||
}
|
||||
|
||||
impl Default for PrefixCacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: default_prefix_cache_enabled(),
|
||||
budget_mb: default_prefix_cache_budget_mb(),
|
||||
max_entries: default_prefix_cache_max_entries(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default for `PrefixCacheConfig::enabled`: the cache is on
/// unless the operator turns it off.
fn default_prefix_cache_enabled() -> bool {
    true
}

/// Serde default for `PrefixCacheConfig::budget_mb`: 1024 MiB (1 GiB)
/// of snapshots per loaded model.
fn default_prefix_cache_budget_mb() -> u64 {
    1024
}

/// Serde default for `PrefixCacheConfig::max_entries`: at most 8 live
/// snapshots per loaded model.
fn default_prefix_cache_max_entries() -> usize {
    8
}
|
||||
|
||||
/// Per-scheme source configuration. Mirrors the shape `hf_hub::ApiBuilder`
|
||||
/// needs: endpoint URL, optional auth token (read from an env var so
|
||||
/// secrets stay out of the config file), and optional cache directory
|
||||
/// disambiguated per source to prevent mirror-vs-canonical collisions.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct SourceConfig {
|
||||
/// Base URL of the registry. Must speak the HF-compatible wire
|
||||
/// format (siblings listing at
|
||||
/// `/api/models/{org}/{name}[/revision/{rev}]`, blob fetch at
|
||||
/// `/{org}/{name}/resolve/{rev}/{path}`).
|
||||
pub endpoint: String,
|
||||
|
||||
/// Environment variable name to read for the bearer token used
|
||||
/// against this source. `None` = anonymous. Reading from env
|
||||
/// (vs. literal token in the config) keeps secrets out of TOML.
|
||||
#[serde(default)]
|
||||
pub auth_env: Option<String>,
|
||||
|
||||
/// Cache directory for this source. The hf-hub
|
||||
/// `models--{org}--{name}/snapshots/...` tree lives directly
|
||||
/// under this path, so distinct sources serving the same
|
||||
/// `org/name` cannot collide on disk.
|
||||
///
|
||||
/// `None` means "share the harness `hf_cache` directory" — only
|
||||
/// safe when the operator has exactly one source configured.
|
||||
#[serde(default)]
|
||||
pub cache_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl CandleHarnessConfig {
|
||||
/// Resolve the effective sources map for this config, synthesising
|
||||
/// a `huggingface` entry from legacy fields (`hf_cache`) when the
|
||||
/// operator hasn't supplied a sources table. Idempotent.
|
||||
///
|
||||
/// Returns a fresh map rather than mutating self so the original
|
||||
/// (operator-typed) config can still be serialized back to TOML
|
||||
/// for diagnostics.
|
||||
pub fn effective_sources(&self) -> HashMap<String, SourceConfig> {
|
||||
let mut out = self.sources.clone();
|
||||
out.entry(DEFAULT_SOURCE_SCHEME.to_string())
|
||||
.or_insert_with(|| SourceConfig {
|
||||
endpoint: DEFAULT_HF_ENDPOINT.to_string(),
|
||||
auth_env: Some("HF_TOKEN".to_string()),
|
||||
cache_dir: self.hf_cache.clone(),
|
||||
});
|
||||
out
|
||||
}
|
||||
|
||||
/// Effective default scheme. Falls back to `DEFAULT_SOURCE_SCHEME`
|
||||
/// when the operator hasn't pinned one.
|
||||
pub fn effective_default_source(&self) -> &str {
|
||||
self.default_source
|
||||
.as_deref()
|
||||
.unwrap_or(DEFAULT_SOURCE_SCHEME)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
@@ -195,114 +35,6 @@ impl Default for NeuronConfig {
|
||||
Self {
|
||||
port: 13131,
|
||||
harnesses: vec![],
|
||||
harness: HarnessSettings::default(),
|
||||
default_models: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for source resolution on [`CandleHarnessConfig`]:
    //! synthesis of the default `huggingface` entry, precedence of
    //! operator-supplied entries, and default-scheme fallback.

    use super::*;

    #[test]
    fn effective_sources_synthesises_huggingface_when_absent() {
        let cfg = CandleHarnessConfig::default();
        let sources = cfg.effective_sources();
        assert!(sources.contains_key("huggingface"));
        let hf = &sources["huggingface"];
        assert_eq!(hf.endpoint, DEFAULT_HF_ENDPOINT);
        assert_eq!(hf.auth_env.as_deref(), Some("HF_TOKEN"));
        assert!(hf.cache_dir.is_none());
    }

    #[test]
    fn effective_sources_carries_legacy_hf_cache_into_synth_entry() {
        // Existing operator configs only set `hf_cache = "/archive3/..."`
        // — the synth must pick that up so the loader keeps using the
        // operator's storage.
        let cfg = CandleHarnessConfig {
            hf_cache: Some(PathBuf::from("/archive3/llm-cache")),
            ..Default::default()
        };
        let sources = cfg.effective_sources();
        assert_eq!(
            sources["huggingface"].cache_dir.as_deref(),
            Some(Path::new("/archive3/llm-cache"))
        );
    }

    #[test]
    fn effective_sources_preserves_explicit_huggingface_entry() {
        // When an operator types out `[harness.candle.sources.huggingface]`
        // explicitly, we must not clobber it with the synth defaults.
        let mut sources = HashMap::new();
        sources.insert(
            "huggingface".to_string(),
            SourceConfig {
                endpoint: "https://huggingface.example.org".into(),
                auth_env: Some("MY_TOKEN".into()),
                cache_dir: Some(PathBuf::from("/operator-cache")),
            },
        );
        let cfg = CandleHarnessConfig {
            hf_cache: Some(PathBuf::from("/legacy-cache")),
            sources,
            ..Default::default()
        };
        let effective = cfg.effective_sources();
        assert_eq!(
            effective["huggingface"].endpoint,
            "https://huggingface.example.org"
        );
        assert_eq!(
            effective["huggingface"].auth_env.as_deref(),
            Some("MY_TOKEN")
        );
        assert_eq!(
            effective["huggingface"].cache_dir.as_deref(),
            Some(Path::new("/operator-cache"))
        );
    }

    #[test]
    fn effective_sources_includes_helexa_alongside_synth_huggingface() {
        let mut sources = HashMap::new();
        sources.insert(
            "helexa".to_string(),
            SourceConfig {
                endpoint: "https://registry.helexa.ai".into(),
                auth_env: Some("HELEXA_TOKEN".into()),
                cache_dir: Some(PathBuf::from("/archive3/llm-cache/helexa")),
            },
        );
        let cfg = CandleHarnessConfig {
            hf_cache: Some(PathBuf::from("/archive3/llm-cache/huggingface")),
            sources,
            ..Default::default()
        };
        let effective = cfg.effective_sources();
        assert_eq!(effective.len(), 2);
        assert_eq!(effective["helexa"].endpoint, "https://registry.helexa.ai");
        // huggingface still gets synth-derived from legacy hf_cache.
        assert_eq!(
            effective["huggingface"].cache_dir.as_deref(),
            Some(Path::new("/archive3/llm-cache/huggingface"))
        );
    }

    #[test]
    fn effective_default_source_falls_back() {
        let cfg = CandleHarnessConfig::default();
        assert_eq!(cfg.effective_default_source(), DEFAULT_SOURCE_SCHEME);
    }

    #[test]
    fn effective_default_source_honours_explicit() {
        let cfg = CandleHarnessConfig {
            default_source: Some("helexa".into()),
            ..Default::default()
        };
        assert_eq!(cfg.effective_default_source(), "helexa");
    }
}
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
//! FFI declarations for the CUDA kernels in `gdn.cu`.
|
||||
//!
|
||||
//! Subset of `EricLBuehler/mistral.rs::mistralrs-core/src/cuda/ffi.rs`
|
||||
//! covering only the Gated DeltaNet kernels we currently use. Other
|
||||
//! kernels in the upstream file (MoE GEMM, top-k, Mamba selective
|
||||
//! scan, etc.) would land here too as we absorb them.
|
||||
//!
|
||||
//! All function declarations are MIT-licensed from upstream and
|
||||
//! unchanged apart from this header.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
// NOTE(review): `dead_code` is silenced presumably because not every
// declaration is called in every build configuration — confirm, and
// narrow the allow to the unused items if that turns out not to be so.
//
// Safety (applies to every function below): these are raw launch shims.
// The caller must pass pointers that are valid for the sizes implied by
// the dimension arguments and a `stream` value that is a live
// `cudaStream_t` cast to `i64`. Assumes all buffers are device memory —
// the recurrence shims forward their pointers straight into CUDA
// kernels; the other three are defined outside this view (TODO confirm).
#[allow(dead_code)]
unsafe extern "C" {
    // GDN (Gated Delta Net) kernels for qwen3_5 / Qwen3-Next.

    /// Per-token GDN recurrence (decode path).
    ///
    /// Layouts per `gdn.cu`: `q`, `k` are `[bh, seq_len, k_dim]`, `v`
    /// and `output` are `[bh, seq_len, v_dim]`, `g` and `beta` are
    /// `[bh, seq_len]`, and `state` is `[bh, k_dim, v_dim]`, updated in
    /// place.
    pub(crate) fn gated_delta_rule_recurrence(
        q: *const f32,
        k: *const f32,
        v: *const f32,
        g: *const f32,
        beta: *const f32,
        state: *mut f32,
        output: *mut f32,
        bh: i32,
        seq_len: i32,
        k_dim: i32,
        v_dim: i32,
        stream: i64,
    );

    /// Chunked GDN recurrence for prefill (processes tokens in BT=64 chunks).
    ///
    /// Takes the same arguments and buffer layouts as
    /// `gated_delta_rule_recurrence`.
    pub(crate) fn chunked_gated_delta_rule_recurrence(
        q: *const f32,
        k: *const f32,
        v: *const f32,
        g: *const f32,
        beta: *const f32,
        state: *mut f32,
        output: *mut f32,
        bh: i32,
        seq_len: i32,
        k_dim: i32,
        v_dim: i32,
        stream: i64,
    );

    /// Single-token causal conv1d (decode path), per the `gdn.cu`
    /// banner. Buffers are untyped; `dtype` selects the element type —
    /// NOTE(review): the dtype encoding is defined on the CUDA side and
    /// is not visible here; confirm against `gdn.cu`.
    pub(crate) fn causal_conv1d_update(
        x: *const c_void,
        weight: *const c_void,
        conv_state: *mut c_void,
        output: *mut c_void,
        batch_size: i32,
        conv_dim: i32,
        kernel_size: i32,
        dtype: i32,
        stream: i64,
    );

    /// Multi-token causal conv1d (prefill path), per the `gdn.cu`
    /// banner. `conv_state_out` presumably receives the conv state to
    /// carry into subsequent decode steps — TODO confirm. `dtype` as
    /// for `causal_conv1d_update`.
    pub(crate) fn causal_conv1d_full(
        x: *const c_void,
        weight: *const c_void,
        conv_state_out: *mut c_void,
        output: *mut c_void,
        batch_size: i32,
        conv_dim: i32,
        seq_len: i32,
        kernel_size: i32,
        dtype: i32,
        stream: i64,
    );

    /// Fused GDN gating, per the `gdn.cu` banner:
    /// `beta = sigmoid(b)` and
    /// `g = -exp(A_log) * softplus(a + dt_bias)`.
    ///
    /// `a_log` and `dt_bias` are always `f32`; the remaining buffers
    /// are untyped with `dtype` selecting their element type (encoding
    /// defined on the CUDA side — confirm against `gdn.cu`).
    pub(crate) fn fused_gdn_gating(
        b: *const c_void,
        a: *const c_void,
        a_log: *const f32,
        dt_bias: *const f32,
        beta_out: *mut c_void,
        g_out: *mut c_void,
        total_elements: i32,
        num_heads: i32,
        dtype: i32,
        stream: i64,
    );
}
|
||||
@@ -1,711 +0,0 @@
|
||||
// Gated DeltaNet CUDA kernels for Qwen3-Next (`model_type = "qwen3_5"`).
|
||||
//
|
||||
// Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||
// Upstream path: mistralrs-core/src/cuda/gdn.cu. Local edits in this
|
||||
// file are limited to this banner; the kernels are unchanged so a
|
||||
// diff against upstream stays minimal.
|
||||
//
|
||||
// Five kernels exposed via `extern "C"` shims at the bottom:
|
||||
// - gated_delta_rule_recurrence (per-token decode)
|
||||
// - chunked_gated_delta_rule_recurrence (BT=64 chunked prefill)
|
||||
// - causal_conv1d_update (single-token conv decode)
|
||||
// - causal_conv1d_full (multi-token conv prefill)
|
||||
// - fused_gdn_gating (beta = sigmoid(b);
|
||||
// g = -exp(A_log) * softplus(a + dt_bias))
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include "cuda_fp16.h"
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 1: gated_delta_rule_recurrence (optimized)
|
||||
//
|
||||
// V-tiled recurrence with compile-time K dimension for register residency.
|
||||
// Grid: (ceil(V/BV), B*H), Block: (BV,). Each thread owns BK registers of
|
||||
// state. Shared memory holds k_buf and q_buf (2*BK floats).
|
||||
//
|
||||
// Optimizations over naive version:
|
||||
// - Template BK -> float s[BK] lives in true registers (1 cycle vs ~30)
|
||||
// - #pragma unroll on all k-loops -> full ILP
|
||||
// - Fused decay+kv_mem pass and fused state_update+output pass
|
||||
// - __fmaf_rn intrinsics for guaranteed fused multiply-add
|
||||
// - BV=64 threads -> 2 warps, 6 blocks/SM on Ampere
|
||||
//
|
||||
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||
// ============================================================================
|
||||
|
||||
// Optimized kernel: BK known at compile time -> registers + full unrolling
//
// One block per (V-tile, batch*head); thread `tid` owns column
// `v_tile * BV + tid` of the [K, V] state and keeps it in s[BK] for the
// whole sequence. k_t and q_t are staged through shared memory by ALL
// threads of the block.
//
// LOCAL CHANGE vs. upstream: threads whose column lies past v_dim (last
// tile when v_dim % BV != 0) used to `return` up front. They then never
// performed their share of the strided k_buf/q_buf loads (leaving those
// entries unwritten for the live threads) and skipped every
// __syncthreads(). Such threads now stay resident, take part in the
// shared loads and barriers, and only have their global loads/stores of
// per-column data masked off. Base offsets are computed in size_t so
// bh * seq_len * dim cannot overflow 32-bit int.
template <int BK, int BV>
__global__ void gated_delta_rule_recurrence_kernel_tiled(
    const float *__restrict__ q,    // [BH, S, K]
    const float *__restrict__ k,    // [BH, S, K]
    const float *__restrict__ v,    // [BH, S, V]
    const float *__restrict__ g,    // [BH, S]
    const float *__restrict__ beta, // [BH, S]
    float *__restrict__ state,      // [BH, K, V]
    float *__restrict__ output,     // [BH, S, V]
    int seq_len, int v_dim) {

  const int v_tile = blockIdx.x;       // which V-tile
  const int bh = blockIdx.y;           // batch*head index
  const int tid = threadIdx.x;         // thread within tile [0, BV)
  const int v_idx = v_tile * BV + tid; // global V index

  // False only for the padding threads of the last V-tile.
  const bool active = v_idx < v_dim;

  // Pointers for this (batch, head)
  const size_t bh_seq = (size_t)bh * (size_t)seq_len;
  const float *q_bh = q + bh_seq * BK;
  const float *k_bh = k + bh_seq * BK;
  const float *v_bh = v + bh_seq * v_dim;
  const float *g_bh = g + bh_seq;
  const float *beta_bh = beta + bh_seq;
  float *state_bh = state + (size_t)bh * BK * v_dim;
  float *out_bh = output + bh_seq * v_dim;

  // Shared memory: k_buf[BK] + q_buf[BK]
  __shared__ float k_buf[BK];
  __shared__ float q_buf[BK];

  // Load this thread's state column. BK is a compile-time constant and
  // every loop over s[] is fully unrolled, which is what lets the
  // compiler keep the array in registers. Padding threads start from
  // zeros and never write anything back.
  float s[BK];
#pragma unroll
  for (int j = 0; j < BK; j++) {
    s[j] = active ? state_bh[j * v_dim + v_idx] : 0.0f;
  }

  for (int t = 0; t < seq_len; t++) {
    // Collaboratively load k_t into shared memory
    // BK / BV loads per thread (e.g. 128/64 = 2)
#pragma unroll
    for (int j = tid; j < BK; j += BV) {
      k_buf[j] = k_bh[t * BK + j];
    }
    __syncthreads();

    // Load scalars for this timestep
    float decay = expf(g_bh[t]);
    float beta_t = beta_bh[t];
    float v_t = active ? v_bh[t * v_dim + v_idx] : 0.0f;

    // Fused pass 1: decay state + compute kv_mem
    float kv_mem = 0.0f;
#pragma unroll
    for (int j = 0; j < BK; j++) {
      s[j] *= decay;
      kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
    }

    // Delta rule
    float delta = (v_t - kv_mem) * beta_t;

    // Collaboratively load q_t into shared memory
#pragma unroll
    for (int j = tid; j < BK; j += BV) {
      q_buf[j] = q_bh[t * BK + j];
    }
    __syncthreads();

    // Fused pass 2: update state + compute output
    float y_t = 0.0f;
#pragma unroll
    for (int j = 0; j < BK; j++) {
      s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
      y_t = __fmaf_rn(s[j], q_buf[j], y_t);
    }

    if (active) {
      out_bh[t * v_dim + v_idx] = y_t;
    }

    // k_buf/q_buf are overwritten at the top of the next timestep; wait
    // until every thread has finished reading them.
    __syncthreads();
  }

  // Write state back
  if (active) {
#pragma unroll
    for (int j = 0; j < BK; j++) {
      state_bh[j * v_dim + v_idx] = s[j];
    }
  }
}
|
||||
|
||||
// Fallback kernel: runtime k_dim, still V-tiled for occupancy
//
// Same algorithm and thread model as the tiled kernel, but k_dim is a
// runtime value: k_buf/q_buf live in dynamic shared memory (the launch
// must provide 2 * k_dim floats) and the per-thread state column is a
// fixed-capacity array.
//
// PRECONDITION: k_dim <= MAX_K. s[] has MAX_K entries and nothing in
// this kernel checks k_dim against it; a larger k_dim indexes past the
// array.
//
// LOCAL CHANGE vs. upstream: threads whose column lies past v_dim (last
// tile when v_dim % BV != 0) used to `return` up front. They then never
// performed their share of the strided k_buf/q_buf loads and skipped
// every __syncthreads(). Such threads now stay resident, take part in
// the shared loads and barriers, and only have their global
// loads/stores of per-column data masked off. Base offsets are computed
// in size_t so bh * seq_len * dim cannot overflow 32-bit int.
template <int BV, int MAX_K>
__global__ void gated_delta_rule_recurrence_kernel_fallback(
    const float *__restrict__ q, const float *__restrict__ k,
    const float *__restrict__ v, const float *__restrict__ g,
    const float *__restrict__ beta, float *__restrict__ state,
    float *__restrict__ output, int seq_len, int k_dim, int v_dim) {

  const int v_tile = blockIdx.x;
  const int bh = blockIdx.y;
  const int tid = threadIdx.x;
  const int v_idx = v_tile * BV + tid;

  // False only for the padding threads of the last V-tile.
  const bool active = v_idx < v_dim;

  const size_t bh_seq = (size_t)bh * (size_t)seq_len;
  const float *q_bh = q + bh_seq * k_dim;
  const float *k_bh = k + bh_seq * k_dim;
  const float *v_bh = v + bh_seq * v_dim;
  const float *g_bh = g + bh_seq;
  const float *beta_bh = beta + bh_seq;
  float *state_bh = state + (size_t)bh * k_dim * v_dim;
  float *out_bh = output + bh_seq * v_dim;

  // Dynamic shared memory: k_buf[k_dim] followed by q_buf[k_dim]
  extern __shared__ float shared[];
  float *k_buf = shared;
  float *q_buf = shared + k_dim;

  // Per-thread state column; padding threads start from zeros and never
  // write anything back.
  float s[MAX_K];
  for (int j = 0; j < k_dim; j++) {
    s[j] = active ? state_bh[j * v_dim + v_idx] : 0.0f;
  }

  for (int t = 0; t < seq_len; t++) {
    // Collaboratively load k_t into shared memory
    for (int j = tid; j < k_dim; j += BV) {
      k_buf[j] = k_bh[t * k_dim + j];
    }
    __syncthreads();

    float decay = expf(g_bh[t]);
    float beta_t = beta_bh[t];
    float v_t = active ? v_bh[t * v_dim + v_idx] : 0.0f;

    // Pass 1: decay state + compute kv_mem
    float kv_mem = 0.0f;
    for (int j = 0; j < k_dim; j++) {
      s[j] *= decay;
      kv_mem = __fmaf_rn(s[j], k_buf[j], kv_mem);
    }

    // Delta rule
    float delta = (v_t - kv_mem) * beta_t;

    // Collaboratively load q_t into shared memory
    for (int j = tid; j < k_dim; j += BV) {
      q_buf[j] = q_bh[t * k_dim + j];
    }
    __syncthreads();

    // Pass 2: update state + compute output
    float y_t = 0.0f;
    for (int j = 0; j < k_dim; j++) {
      s[j] = __fmaf_rn(k_buf[j], delta, s[j]);
      y_t = __fmaf_rn(s[j], q_buf[j], y_t);
    }

    if (active) {
      out_bh[t * v_dim + v_idx] = y_t;
    }

    // k_buf/q_buf are overwritten at the top of the next timestep; wait
    // until every thread has finished reading them.
    __syncthreads();
  }

  // Write state back
  if (active) {
    for (int j = 0; j < k_dim; j++) {
      state_bh[j * v_dim + v_idx] = s[j];
    }
  }
}
|
||||
|
||||
// Host-side launch shim for the per-token GDN recurrence.
//
//   q, k:    [bh, seq_len, k_dim]      v, output: [bh, seq_len, v_dim]
//   g, beta: [bh, seq_len]             state:     [bh, k_dim, v_dim] (in/out)
//   stream:  a cudaStream_t passed as a 64-bit integer
//
// The pointers are handed straight to the kernels, so they must be
// device-accessible. The launch is asynchronous on `stream`; this
// function neither synchronises nor reports launch errors.
//
// Dispatch: k_dim == 128 and k_dim == 64 use the compile-time-K tiled
// kernel; every other k_dim uses the runtime-K fallback.
//
// NOTE(review): the fallback keeps its per-thread state in a MAX_K
// (256) element array and does not check k_dim against it, so
// k_dim > 256 is out of contract. This shim returns void and has no way
// to report that; callers must enforce it.
extern "C" void gated_delta_rule_recurrence(const float *q, const float *k,
                                            const float *v, const float *g,
                                            const float *beta, float *state,
                                            float *output, int bh, int seq_len,
                                            int k_dim, int v_dim,
                                            int64_t stream) {

  // Empty problem: there is nothing to compute and the state must stay
  // as it is. Returning here also avoids launching a zero-sized grid,
  // which CUDA rejects as an invalid configuration.
  if (bh <= 0 || seq_len <= 0 || k_dim <= 0 || v_dim <= 0)
    return;

  const cudaStream_t custream = (cudaStream_t)stream;

  // All variants share the launch shape: one block of BV threads per
  // (V-tile, batch*head).
  constexpr int BV = 64;
  const dim3 grid((v_dim + BV - 1) / BV, bh);
  const dim3 block(BV);

  if (k_dim == 128) {
    // Fast path for Qwen3-Next (k_dim=128)
    gated_delta_rule_recurrence_kernel_tiled<128, BV>
        <<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
                                       v_dim);
  } else if (k_dim == 64) {
    // Fast path for models with k_dim=64
    gated_delta_rule_recurrence_kernel_tiled<64, BV>
        <<<grid, block, 0, custream>>>(q, k, v, g, beta, state, output, seq_len,
                                       v_dim);
  } else {
    // Fallback for other k_dim values (runtime loop, still V-tiled)
    constexpr int MAX_K = 256;
    // k_buf + q_buf, k_dim floats each
    const size_t smem = 2 * (size_t)k_dim * sizeof(float);
    gated_delta_rule_recurrence_kernel_fallback<BV, MAX_K>
        <<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
                                          seq_len, k_dim, v_dim);
  }
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 1b: chunked_gated_delta_rule_recurrence (prefill optimization)
|
||||
//
|
||||
// Processes prefill tokens in BT-token chunks instead of one at a time.
|
||||
// Within each chunk: parallel prefix sum of g, cooperative kk_dot computation,
|
||||
// forward substitution (triangular solve), output computation, and state
|
||||
// update.
|
||||
//
|
||||
// Same thread model as Kernel 1: one block per (v_tile, batch*head),
|
||||
// one thread per V-column. Each thread owns BK registers of state.
|
||||
//
|
||||
// Shared memory holds:
|
||||
// k_chunk[BT * BK] -- key vectors for current chunk
|
||||
// kk_dot[BT * BT] -- dot(k[i], k[j]) lower-triangular matrix
|
||||
// gcum[BT] -- cumulative sum of g within chunk
|
||||
// beta_s[BT] -- beta values for chunk
|
||||
// q_buf[BK] -- q vector (loaded one row at a time)
|
||||
//
|
||||
// q,k: [BH, S, K] v: [BH, S, V] g,beta: [BH, S]
|
||||
// state: [BH, K, V] (in/out) output: [BH, S, V]
|
||||
// ============================================================================
|
||||
|
||||
// Chunked prefill kernel. One block per (V-tile, batch*head); thread
// `tid` owns column `v_tile * BV + tid` of the [K, V] state in s[BK].
// The launch must provide BT*BK + BT*BT + 2*BT + BK floats of dynamic
// shared memory (layout below).
//
// LOCAL CHANGE vs. upstream: threads whose column lies past v_dim (last
// tile when v_dim % BV != 0) used to `return` up front. Every
// cooperative step here (k_chunk / beta_s / gcum loads, the prefix scan,
// kk_dot, q_buf) is distributed over tid, so those threads' shares were
// silently skipped, and they also skipped every __syncthreads(). Such
// threads now stay resident, take part in all shared-memory work and
// barriers, and only have their global loads/stores of per-column data
// masked off. Base offsets are computed in size_t so
// bh * seq_len * dim cannot overflow 32-bit int.
template <int BT, int BK, int BV>
__global__ void
chunked_gated_delta_rule_kernel(const float *__restrict__ q,    // [BH, S, K]
                                const float *__restrict__ k,    // [BH, S, K]
                                const float *__restrict__ v,    // [BH, S, V]
                                const float *__restrict__ g,    // [BH, S]
                                const float *__restrict__ beta, // [BH, S]
                                float *__restrict__ state,      // [BH, K, V]
                                float *__restrict__ output,     // [BH, S, V]
                                int seq_len, int v_dim) {

  // The beta/g load and the prefix scan use tid as the chunk position,
  // so the block needs at least one thread per chunk token.
  static_assert(BV >= BT, "chunked GDN kernel requires BV >= BT");

  const int v_tile = blockIdx.x;
  const int bh = blockIdx.y;
  const int tid = threadIdx.x;
  const int v_idx = v_tile * BV + tid;

  // False only for the padding threads of the last V-tile.
  const bool active = v_idx < v_dim;

  const int num_chunks = (seq_len + BT - 1) / BT;

  // Pointers for this (batch, head)
  const size_t bh_seq = (size_t)bh * (size_t)seq_len;
  const float *q_bh = q + bh_seq * BK;
  const float *k_bh = k + bh_seq * BK;
  const float *v_bh = v + bh_seq * v_dim;
  const float *g_bh = g + bh_seq;
  const float *beta_bh = beta + bh_seq;
  float *state_bh = state + (size_t)bh * BK * v_dim;
  float *out_bh = output + bh_seq * v_dim;

  // Dynamic shared memory layout
  extern __shared__ float smem[];
  float *k_chunk = smem;                  // [BT * BK]
  float *kk_dot = smem + BT * BK;         // [BT * BT]
  float *gcum = smem + BT * BK + BT * BT; // [BT]
  float *beta_s = gcum + BT;              // [BT]
  float *q_buf = beta_s + BT;             // [BK]

  // Load this thread's state column; padding threads start from zeros
  // and never write anything back.
  float s[BK];
#pragma unroll
  for (int j = 0; j < BK; j++) {
    s[j] = active ? state_bh[j * v_dim + v_idx] : 0.0f;
  }

  // Per-thread corrected deltas for the current chunk. Only entries
  // [0, chunk_len) are written and read in each chunk.
  float delta[BT];

  for (int c = 0; c < num_chunks; c++) {
    const int chunk_start = c * BT;
    const int chunk_len = min(BT, seq_len - chunk_start);

    // === Phase 1: Cooperative load of k, beta, g into shared memory ===
    for (int t = 0; t < chunk_len; t++) {
      for (int j = tid; j < BK; j += BV) {
        k_chunk[t * BK + j] = k_bh[(chunk_start + t) * BK + j];
      }
    }
    if (tid < chunk_len) {
      beta_s[tid] = beta_bh[chunk_start + tid];
      gcum[tid] = g_bh[chunk_start + tid];
    }
    __syncthreads();

    // === Phase 1b: Parallel prefix sum of g (Hillis-Steele) ===
    // Read, barrier, write, barrier: the first barrier keeps a thread
    // from overwriting gcum[tid] before its neighbour has read it.
    for (int stride = 1; stride < BT; stride <<= 1) {
      const bool takes_part = tid < chunk_len && tid >= stride;
      float prev = 0.0f;
      if (takes_part)
        prev = gcum[tid - stride];
      __syncthreads();
      if (takes_part)
        gcum[tid] += prev;
      __syncthreads();
    }

    // === Phase 2: Compute kk_dot[i][j] = dot(k[i], k[j]) for j < i ===
    // Only lower-triangular entries needed (strictly lower)
    for (int idx = tid; idx < chunk_len * chunk_len; idx += BV) {
      int i = idx / chunk_len;
      int j = idx % chunk_len;
      if (j < i) {
        float dot = 0.0f;
        for (int d = 0; d < BK; d++) {
          dot = __fmaf_rn(k_chunk[i * BK + d], k_chunk[j * BK + d], dot);
        }
        kk_dot[i * BT + j] = dot;
      }
    }
    __syncthreads();

    // === Phase 3: Forward substitution (per V-column) ===
    // Computes corrected delta values via triangular solve
    for (int i = 0; i < chunk_len; i++) {
      float v_i = active ? v_bh[(chunk_start + i) * v_dim + v_idx] : 0.0f;
      float decay_i = expf(gcum[i]);
      float beta_i = beta_s[i];

      // Inter-chunk contribution: state @ k[i] with decay
      float kv_mem = 0.0f;
#pragma unroll
      for (int d = 0; d < BK; d++) {
        kv_mem = __fmaf_rn(s[d] * decay_i, k_chunk[i * BK + d], kv_mem);
      }

      float rhs = beta_i * (v_i - kv_mem);

      // Subtract lower-triangular contributions (intra-chunk)
      for (int j = 0; j < i; j++) {
        float a_ij = beta_i * kk_dot[i * BT + j] * expf(gcum[i] - gcum[j]);
        rhs -= a_ij * delta[j];
      }
      delta[i] = rhs;
    }

    // === Phase 4: Output computation (per V-column) ===
    for (int i = 0; i < chunk_len; i++) {
      // Cooperatively load q[i] into shared
      for (int j = tid; j < BK; j += BV) {
        q_buf[j] = q_bh[(chunk_start + i) * BK + j];
      }
      __syncthreads();

      float decay_i = expf(gcum[i]);

      // Inter-chunk contribution: q[i] @ (state * decay)
      float o_val = 0.0f;
#pragma unroll
      for (int d = 0; d < BK; d++) {
        o_val = __fmaf_rn(q_buf[d], s[d] * decay_i, o_val);
      }

      // Intra-chunk contribution: sum_{j<=i} dot(q[i], k[j]) * delta[j] *
      // exp(gcum[i] - gcum[j])
      for (int j = 0; j <= i; j++) {
        float qk_dot = 0.0f;
        for (int d = 0; d < BK; d++) {
          qk_dot = __fmaf_rn(q_buf[d], k_chunk[j * BK + d], qk_dot);
        }
        o_val += qk_dot * delta[j] * expf(gcum[i] - gcum[j]);
      }

      if (active) {
        out_bh[(chunk_start + i) * v_dim + v_idx] = o_val;
      }
      // q_buf is overwritten on the next iteration; wait for all reads.
      __syncthreads();
    }

    // === Phase 5: State update for next chunk ===
    float g_total = gcum[chunk_len - 1];
#pragma unroll
    for (int d = 0; d < BK; d++) {
      float s_new = s[d] * expf(g_total);
      for (int t = 0; t < chunk_len; t++) {
        s_new += k_chunk[t * BK + d] * delta[t] * expf(g_total - gcum[t]);
      }
      s[d] = s_new;
    }

    // The next chunk's Phase 1 overwrites k_chunk/gcum/beta_s; wait
    // until every thread has finished with this chunk's copies.
    __syncthreads();
  }

  // Write final state back
  if (active) {
#pragma unroll
    for (int j = 0; j < BK; j++) {
      state_bh[j * v_dim + v_idx] = s[j];
    }
  }
}
|
||||
|
||||
// Launches the chunked kernel for one <BT, BK, BV> tile configuration.
//
// Returns true once the launch has been enqueued on `custream`. Returns false
// WITHOUT launching when the runtime rejects the dynamic shared-memory opt-in,
// so the caller can route the work to the sequential kernel instead of
// enqueueing a launch that is bound to fail.
template <int BT, int BK, int BV>
static bool launch_chunked_gated_delta_rule(
    const float *q, const float *k, const float *v, const float *g,
    const float *beta, float *state, float *output, int bh, int seq_len,
    int v_dim, cudaStream_t custream) {
  // Shared memory: BT*BK + BT*BT + BT + BT + BK floats.
  // NOTE(review): this presumably mirrors the k_chunk / kk_dot / gcum /
  // beta_s / q_buf buffers the kernel indexes — the carve-up itself is
  // declared in the kernel, confirm there before changing the formula.
  constexpr size_t smem = (BT * BK + BT * BT + 2 * BT + BK) * sizeof(float);

  // Request extended shared memory. For BK = 128 the footprint is 50176
  // bytes, which exceeds the default 48 KiB per-block dynamic limit and
  // therefore needs the explicit opt-in to succeed.
  auto kernel = chunked_gated_delta_rule_kernel<BT, BK, BV>;
  if (cudaFuncSetAttribute(kernel,
                           cudaFuncAttributeMaxDynamicSharedMemorySize,
                           (int)smem) != cudaSuccess) {
    // Consume the recorded error so it is not misattributed to a later,
    // unrelated runtime call.
    (void)cudaGetLastError();
    return false;
  }

  // One block per (BV-wide slice of the V dimension, batch*head) pair,
  // with BV threads per block.
  dim3 grid((v_dim + BV - 1) / BV, bh);
  dim3 block(BV);
  kernel<<<grid, block, smem, custream>>>(q, k, v, g, beta, state, output,
                                          seq_len, v_dim);
  return true;
}

// Host entry point for the chunked gated delta rule recurrence.
//
// Dispatches on `k_dim`: 128 and 64 have dedicated chunked instantiations
// (BT = 64, BV = 64); every other value — and any device on which the
// chunked kernel's shared-memory request is refused — is handled by the
// sequential `gated_delta_rule_recurrence` launcher.
//
// `stream` is a cudaStream_t handle passed as a 64-bit integer. All pointer
// arguments are forwarded to the kernels unchanged.
extern "C" void chunked_gated_delta_rule_recurrence(
    const float *q, const float *k, const float *v, const float *g,
    const float *beta, float *state, float *output, int bh, int seq_len,
    int k_dim, int v_dim, int64_t stream) {

  const cudaStream_t custream = (cudaStream_t)stream;

  bool launched = false;
  if (k_dim == 128) {
    launched = launch_chunked_gated_delta_rule<64, 128, 64>(
        q, k, v, g, beta, state, output, bh, seq_len, v_dim, custream);
  } else if (k_dim == 64) {
    launched = launch_chunked_gated_delta_rule<64, 64, 64>(
        q, k, v, g, beta, state, output, bh, seq_len, v_dim, custream);
  }

  if (!launched) {
    // Fallback: use the sequential kernel for unsupported k_dim, or when the
    // chunked kernel cannot obtain its shared memory on this device.
    gated_delta_rule_recurrence(q, k, v, g, beta, state, output, bh, seq_len,
                                k_dim, v_dim, stream);
  }
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 2a: causal_conv1d_update (decode path, single step)
|
||||
//
|
||||
// Each thread handles one channel: shift conv_state left by 1,
|
||||
// insert new value, dot product with weight, apply SiLU.
|
||||
//
|
||||
// x: [B, conv_dim, 1] weight: [conv_dim, kernel_size]
|
||||
// conv_state: [B, conv_dim, kernel_size] (in/out)
|
||||
// output: [B, conv_dim, 1]
|
||||
// ============================================================================
|
||||
|
||||
// Single-step (decode) causal conv1d: one thread per (batch, channel).
//
// For its channel the thread slides the stored window one slot towards index
// 0, appends the new sample from `x`, and writes SiLU(window · weight) to
// `output`. The accumulation runs in float regardless of T and visits the
// taps in increasing index order.
//
// Launch layout: blockIdx.x * blockDim.x + threadIdx.x selects the channel,
// blockIdx.y selects the batch entry.
template <typename T>
__global__ void causal_conv1d_update_kernel(
    const T *__restrict__ x,      // [B, conv_dim, 1]
    const T *__restrict__ weight, // [conv_dim, kernel_size]
    T *__restrict__ conv_state,   // [B, conv_dim, kernel_size]
    T *__restrict__ output,       // [B, conv_dim, 1]
    int batch_size, int conv_dim, int kernel_size) {

  const int batch = blockIdx.y;
  const int channel = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the end of the channel axis (tail block) have no work.
  if (batch >= batch_size || channel >= conv_dim)
    return;

  const int row = batch * conv_dim + channel;
  T *window = conv_state + row * kernel_size;
  const T *taps = weight + channel * kernel_size;
  const int newest = kernel_size - 1;

  // Shift the window left by one and accumulate the dot product in the same
  // pass; each shifted value is multiplied by the tap at its NEW position.
  float dot = 0.0f;
  for (int t = 0; t < newest; ++t) {
    const T moved = window[t + 1];
    window[t] = moved;
    dot += (float)moved * (float)taps[t];
  }

  // The incoming sample takes the last slot and the last tap.
  const T incoming = x[row];
  window[newest] = incoming;
  dot += (float)incoming * (float)taps[newest];

  // SiLU activation: v * sigmoid(v)
  const float gate = 1.0f / (1.0f + expf(-dot));
  output[row] = (T)(dot * gate);
}
|
||||
|
||||
// Host launcher for the single-step causal conv1d update.
//
// x:          [B, conv_dim, 1]            weight: [conv_dim, kernel_size]
// conv_state: [B, conv_dim, kernel_size]  (updated in place by the kernel)
// output:     [B, conv_dim, 1]
//
// `dtype` selects the element type: 0 launches the __half instantiation,
// any other value launches the __nv_bfloat16 one. `stream` is a cudaStream_t
// handle passed as a 64-bit integer.
//
// The call is a no-op when any extent is non-positive: a zero-sized grid is
// an invalid launch configuration, and kernel_size <= 0 would make the kernel
// store to conv_state[kernel_size - 1], i.e. before the start of its window.
extern "C" void causal_conv1d_update(const void *x, const void *weight,
                                     void *conv_state, void *output,
                                     int batch_size, int conv_dim,
                                     int kernel_size, int dtype,
                                     int64_t stream) {
  if (batch_size <= 0 || conv_dim <= 0 || kernel_size <= 0)
    return;

  const cudaStream_t custream = (cudaStream_t)stream;

  // 256 channels per block along x; one grid row per batch entry.
  dim3 block(256);
  dim3 grid((conv_dim + 255) / 256, batch_size);

  if (dtype == 0) {
    // f16
    causal_conv1d_update_kernel<__half><<<grid, block, 0, custream>>>(
        (const __half *)x, (const __half *)weight, (__half *)conv_state,
        (__half *)output, batch_size, conv_dim, kernel_size);
  } else {
    // bf16
    causal_conv1d_update_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
        (const __nv_bfloat16 *)x, (const __nv_bfloat16 *)weight,
        (__nv_bfloat16 *)conv_state, (__nv_bfloat16 *)output, batch_size,
        conv_dim, kernel_size);
  }
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 2b: causal_conv1d_full (prefill path)
|
||||
//
|
||||
// Each thread handles one (channel, position): causal window with
|
||||
// zero-padding, dot product with weight, SiLU.
|
||||
// A second pass writes the conv_state from the last kernel_size positions.
|
||||
//
|
||||
// x: [B, conv_dim, S] weight: [conv_dim, kernel_size]
|
||||
// conv_state_out: [B, conv_dim, kernel_size] output: [B, conv_dim, S]
|
||||
// ============================================================================
|
||||
|
||||
// Prefill causal conv1d + SiLU.
//
// Each thread owns one channel (blockIdx.x * blockDim.x + threadIdx.x) and
// walks the (batch, position) pairs assigned to its block with a stride of
// gridDim.z / gridDim.y. With a grid that has one block row per position and
// one block layer per batch entry every thread performs exactly one
// iteration; with a clamped grid the loops pick up the remainder, which is
// what lets the launcher stay inside the 65535 limit on gridDim.y / gridDim.z.
//
// Positions before the start of the sequence contribute zero (causal
// zero-padding). Row offsets are computed in size_t so that
// batch * conv_dim * seq_len may exceed INT_MAX.
template <typename T>
__global__ void causal_conv1d_full_kernel(
    const T *__restrict__ x,      // [B, conv_dim, S]
    const T *__restrict__ weight, // [conv_dim, kernel_size]
    T *__restrict__ output,       // [B, conv_dim, S]
    int batch_size, int conv_dim, int seq_len, int kernel_size) {

  const int ch = blockIdx.x * blockDim.x + threadIdx.x;
  if (ch >= conv_dim)
    return;

  const T *w = weight + (size_t)ch * kernel_size;

  for (int b = blockIdx.z; b < batch_size; b += (int)gridDim.z) {
    const size_t row = ((size_t)b * conv_dim + ch) * (size_t)seq_len;
    const T *x_bch = x + row;
    T *out_bch = output + row;

    for (int pos = blockIdx.y; pos < seq_len; pos += (int)gridDim.y) {
      // Causal convolution: sum over the kernel_size window ending at pos
      float acc = 0.0f;
      for (int i = 0; i < kernel_size; i++) {
        int src_pos = pos - (kernel_size - 1) + i;
        float x_val = (src_pos >= 0) ? (float)x_bch[src_pos] : 0.0f;
        acc += x_val * (float)w[i];
      }

      // SiLU
      float sig = 1.0f / (1.0f + expf(-acc));
      out_bch[pos] = (T)(acc * sig);
    }
  }
}

// Captures the convolution state to carry into decoding: for every
// (batch, channel) row, the last `kernel_size` samples of `x`, left-padded
// with zeros when the sequence is shorter than the window.
//
// One thread per channel; batch entries are visited with a stride of
// gridDim.y so the launcher may clamp that grid dimension.
template <typename T>
__global__ void save_conv_state_kernel(
    const T *__restrict__ x,        // [B, conv_dim, S]
    T *__restrict__ conv_state_out, // [B, conv_dim, kernel_size]
    int batch_size, int conv_dim, int seq_len, int kernel_size) {

  const int ch = blockIdx.x * blockDim.x + threadIdx.x;
  if (ch >= conv_dim)
    return;

  // Number of leading zero slots (<= 0 when seq_len >= kernel_size).
  const int pad = kernel_size - seq_len;

  for (int b = blockIdx.y; b < batch_size; b += (int)gridDim.y) {
    const size_t row = (size_t)b * conv_dim + ch;
    const T *x_bch = x + row * (size_t)seq_len;
    T *cs = conv_state_out + row * (size_t)kernel_size;

    for (int i = 0; i < kernel_size; i++) {
      if (i < pad) {
        cs[i] = (T)0.0f;
      } else {
        // i >= pad guarantees seq_len - kernel_size + i >= 0.
        cs[i] = x_bch[seq_len - kernel_size + i];
      }
    }
  }
}

// Enqueues both prefill kernels for element type T on `custream`.
template <typename T>
static void launch_causal_conv1d_full(const void *x, const void *weight,
                                      void *conv_state_out, void *output,
                                      int batch_size, int conv_dim,
                                      int seq_len, int kernel_size,
                                      cudaStream_t custream) {
  // gridDim.y and gridDim.z are limited to 65535; larger extents are covered
  // by the stride loops inside the kernels.
  const int kMaxGridYZ = 65535;

  const dim3 block(256);
  const unsigned ch_blocks = (unsigned)((conv_dim + 255) / 256);
  const unsigned batch_blocks =
      (unsigned)(batch_size < kMaxGridYZ ? batch_size : kMaxGridYZ);

  // Main convolution kernel. Skipped for an empty sequence, where there is
  // no output to produce and a zero-sized grid would be an invalid launch.
  if (seq_len > 0) {
    const unsigned pos_blocks =
        (unsigned)(seq_len < kMaxGridYZ ? seq_len : kMaxGridYZ);
    dim3 grid(ch_blocks, pos_blocks, batch_blocks);
    causal_conv1d_full_kernel<T><<<grid, block, 0, custream>>>(
        (const T *)x, (const T *)weight, (T *)output, batch_size, conv_dim,
        seq_len, kernel_size);
  }

  // Save conv state (second pass on the same stream).
  dim3 grid2(ch_blocks, batch_blocks);
  save_conv_state_kernel<T><<<grid2, block, 0, custream>>>(
      (const T *)x, (T *)conv_state_out, batch_size, conv_dim, seq_len,
      kernel_size);
}

// Host launcher for the prefill causal conv1d.
//
// x:              [B, conv_dim, S]            weight: [conv_dim, kernel_size]
// conv_state_out: [B, conv_dim, kernel_size]  output: [B, conv_dim, S]
//
// `dtype` selects the element type: 0 launches the __half instantiations,
// any other value the __nv_bfloat16 ones. `stream` is a cudaStream_t handle
// passed as a 64-bit integer. The call is a no-op when batch_size or conv_dim
// is non-positive.
extern "C" void causal_conv1d_full(const void *x, const void *weight,
                                   void *conv_state_out, void *output,
                                   int batch_size, int conv_dim, int seq_len,
                                   int kernel_size, int dtype, int64_t stream) {
  if (batch_size <= 0 || conv_dim <= 0)
    return;

  const cudaStream_t custream = (cudaStream_t)stream;

  if (dtype == 0) {
    launch_causal_conv1d_full<__half>(x, weight, conv_state_out, output,
                                      batch_size, conv_dim, seq_len,
                                      kernel_size, custream);
  } else {
    launch_causal_conv1d_full<__nv_bfloat16>(x, weight, conv_state_out, output,
                                             batch_size, conv_dim, seq_len,
                                             kernel_size, custream);
  }
}
|
||||
|
||||
// ============================================================================
|
||||
// Kernel 3: fused_gdn_gating
|
||||
//
|
||||
// Fuses: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
|
||||
// a_log and dt_bias are per-head (broadcast over batch*seq).
|
||||
//
|
||||
// b, a: [total] a_log, dt_bias: [num_heads]
|
||||
// beta_out, g_out: [total]
|
||||
// ============================================================================
|
||||
|
||||
// Fused gating for Gated DeltaNet, one thread per element:
//
//   beta = sigmoid(b)
//   g    = -exp(a_log[head]) * softplus(a + dt_bias[head])
//
// `b` and `a` are read as T and promoted to float; both results are computed
// in float and cast back to T on store. The head index is `idx % num_heads`,
// i.e. the per-head parameters repeat along the flattened element index
// (innermost axis = heads). `num_heads` must be non-zero.
template <typename T>
__global__ void
fused_gdn_gating_kernel(const T *__restrict__ b,           // [total]
                        const T *__restrict__ a,           // [total]
                        const float *__restrict__ a_log,   // [num_heads]
                        const float *__restrict__ dt_bias, // [num_heads]
                        T *__restrict__ beta_out,          // [total]
                        T *__restrict__ g_out,             // [total]
                        int total_elements, int num_heads) {

  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= total_elements)
    return;

  // Head index: elements are laid out as [..., num_heads]
  const int head_idx = idx % num_heads;

  // beta = sigmoid(b)
  const float b_val = (float)b[idx];
  const float beta = 1.0f / (1.0f + expf(-b_val));

  // g = -exp(a_log) * softplus(a + dt_bias)
  const float a_val = (float)a[idx];
  const float a_log_val = a_log[head_idx];
  const float dt_bias_val = dt_bias[head_idx];
  const float sp_input = a_val + dt_bias_val;

  // Numerically stable softplus. The naive log(1 + exp(x)) overflows to +inf
  // once exp(x) leaves float range (x > ~88), which would turn g into -inf.
  // Above the threshold softplus(x) differs from x by less than exp(-20)
  // (~2e-9), far below float resolution at that magnitude, so x is returned
  // directly. Below it, log1pf keeps precision when exp(x) is tiny.
  const float kSoftplusThreshold = 20.0f;
  const float softplus_val = (sp_input > kSoftplusThreshold)
                                 ? sp_input
                                 : log1pf(expf(sp_input));

  const float g_val = -expf(a_log_val) * softplus_val;

  beta_out[idx] = (T)beta;
  g_out[idx] = (T)g_val;
}
|
||||
|
||||
// Host launcher for the fused GDN gating kernel.
//
// b, a:              [total_elements] in the dtype selected by `dtype`
// a_log, dt_bias:    [num_heads] in float
// beta_out, g_out:   [total_elements] in the dtype selected by `dtype`
//
// `dtype` == 0 launches the __half instantiation, any other value the
// __nv_bfloat16 one. `stream` is a cudaStream_t handle passed as a 64-bit
// integer.
//
// The call is a no-op when there is nothing to compute (total_elements <= 0,
// which would otherwise request a zero-sized grid) or when num_heads <= 0
// (the kernel derives the head index with `idx % num_heads`).
extern "C" void fused_gdn_gating(const void *b, const void *a,
                                 const float *a_log, const float *dt_bias,
                                 void *beta_out, void *g_out,
                                 int total_elements, int num_heads, int dtype,
                                 int64_t stream) {
  if (total_elements <= 0 || num_heads <= 0)
    return;

  const cudaStream_t custream = (cudaStream_t)stream;

  // 1-D launch, 256 elements per block, rounded up to cover the tail.
  dim3 block(256);
  dim3 grid((total_elements + 255) / 256);

  if (dtype == 0) {
    fused_gdn_gating_kernel<__half><<<grid, block, 0, custream>>>(
        (const __half *)b, (const __half *)a, a_log, dt_bias,
        (__half *)beta_out, (__half *)g_out, total_elements, num_heads);
  } else {
    fused_gdn_gating_kernel<__nv_bfloat16><<<grid, block, 0, custream>>>(
        (const __nv_bfloat16 *)b, (const __nv_bfloat16 *)a, a_log, dt_bias,
        (__nv_bfloat16 *)beta_out, (__nv_bfloat16 *)g_out, total_elements,
        num_heads);
  }
}
|
||||
@@ -1,486 +0,0 @@
|
||||
//! Rust wrappers around the Gated DeltaNet CUDA kernels in `gdn.cu`.
|
||||
//!
|
||||
//! Ported verbatim from `EricLBuehler/mistral.rs` under MIT terms.
|
||||
//! Upstream path: `mistralrs-core/src/cuda/gdn.rs`. The only edits in
|
||||
//! this file are this header comment — the FFI path module name is
|
||||
//! `crate::cuda::ffi`, identical to upstream's layout.
|
||||
|
||||
#![allow(clippy::cast_possible_truncation)]
|
||||
|
||||
use candle_core::{Result, Tensor};
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
use candle_core::DType;
|
||||
|
||||
/// CUDA-accelerated gated delta rule recurrence.
///
/// Inputs (all contiguous, f32):
///   q, k: [BH, S, K]    v: [BH, S, V]    g, beta: [BH, S]
///   state: [BH, K, V]   (mutated in place)
///
/// Returns: output [BH, S, V]
///
/// # Errors
///
/// Fails when `q` is not rank 3, when any tensor is not an f32 CUDA tensor,
/// when a tensor holds fewer elements than the shapes above imply (the kernel
/// would otherwise read or write past the end of its buffer), or when a
/// dimension does not fit in the `i32` the kernel takes.
///
/// Contiguity is a documented precondition and is NOT checked here: the
/// kernel receives raw base pointers and ignores strides.
#[cfg(feature = "cuda")]
pub fn gated_delta_rule_recurrence_cuda(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    g: &Tensor,
    beta: &Tensor,
    state: &mut Tensor,
) -> Result<Tensor> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;

    /// Rejects a tensor that holds fewer elements than the kernel will touch.
    fn require_elems(name: &str, tensor: &Tensor, needed: usize) -> Result<()> {
        let have = tensor.elem_count();
        if have < needed {
            candle_core::bail!(
                "gated_delta_rule_recurrence_cuda: {} holds {} elements, kernel needs {}",
                name,
                have,
                needed
            );
        }
        Ok(())
    }

    /// Converts a dimension to the `i32` the kernel expects, without truncating.
    fn kernel_dim(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(dim) => Ok(dim),
            Err(_) => candle_core::bail!(
                "gated_delta_rule_recurrence_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    // Borrows the f32 CUDA slice behind a storage guard, or bails with `$msg`.
    macro_rules! cuda_f32 {
        ($guard:expr, $msg:tt) => {
            match &*$guard {
                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
                _ => candle::bail!($msg),
            }
        };
    }

    let (bh, seq_len, k_dim) = q.dims3()?;
    let v_dim = v.dim(2)?;

    let dev = q.device().as_cuda_device()?;

    // Validate everything the kernel will index before handing out pointers.
    require_elems("k", k, bh * seq_len * k_dim)?;
    require_elems("v", v, bh * seq_len * v_dim)?;
    require_elems("g", g, bh * seq_len)?;
    require_elems("beta", beta, bh * seq_len)?;
    require_elems("state", state, bh * k_dim * v_dim)?;

    let bh_arg = kernel_dim("bh", bh)?;
    let seq_len_arg = kernel_dim("seq_len", seq_len)?;
    let k_dim_arg = kernel_dim("k_dim", k_dim)?;
    let v_dim_arg = kernel_dim("v_dim", v_dim)?;

    // Each guard must stay alive for as long as the slice borrowed from it.
    let (q_guard, q_layout) = q.storage_and_layout();
    let q_s = cuda_f32!(q_guard, "q must be a cuda tensor");
    let q_offset = q_layout.start_offset();

    let (k_guard, k_layout) = k.storage_and_layout();
    let k_s = cuda_f32!(k_guard, "k must be a cuda tensor");
    let k_offset = k_layout.start_offset();

    let (v_guard, v_layout) = v.storage_and_layout();
    let v_s = cuda_f32!(v_guard, "v must be a cuda tensor");
    let v_offset = v_layout.start_offset();

    let (g_guard, g_layout) = g.storage_and_layout();
    let g_s = cuda_f32!(g_guard, "g must be a cuda tensor");
    let g_offset = g_layout.start_offset();

    let (beta_guard, beta_layout) = beta.storage_and_layout();
    let beta_s = cuda_f32!(beta_guard, "beta must be a cuda tensor");
    let beta_offset = beta_layout.start_offset();

    let (state_guard, state_layout) = state.storage_and_layout();
    let state_s = cuda_f32!(state_guard, "state must be a cuda tensor");
    let state_offset = state_layout.start_offset();

    // SAFETY: the buffer is uninitialized on return from `alloc`; it is only
    // exposed to callers after the kernel below has been enqueued on the same
    // device stream. NOTE(review): this relies on the kernel writing every
    // output element — confirm against the kernel source.
    let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;

    let stream = dev.cuda_stream().cu_stream() as i64;

    // SAFETY: every pointer comes from a live CUDA slice whose guard is still
    // in scope, each input was checked above to hold at least as many
    // elements as its documented shape, and `output_buf` was sized for
    // [BH, S, V]. The kernel writes `state` through the raw pointer.
    unsafe {
        crate::cuda::ffi::gated_delta_rule_recurrence(
            q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
            k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
            v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
            g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
            beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
            state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
            output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
            bh_arg,
            seq_len_arg,
            k_dim_arg,
            v_dim_arg,
            stream,
        );
    }

    // The kernel wrote state in place via the raw pointer, so only the output
    // needs wrapping into a new tensor.
    let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
    Ok(Tensor::from((
        candle::Storage::Cuda(output_storage),
        (bh, seq_len, v_dim),
    )))
}
|
||||
|
||||
/// Stand-in for builds without the `cuda` feature.
///
/// Keeps the symbol and signature available so call sites compile on
/// CPU-only builds; every call returns an error and no argument is read.
#[cfg(not(feature = "cuda"))]
// Silences unused warnings on the non-cuda build.
#[allow(unused)]
pub fn gated_delta_rule_recurrence_cuda(
    _q: &Tensor,
    _k: &Tensor,
    _v: &Tensor,
    _g: &Tensor,
    _beta: &Tensor,
    _state: &mut Tensor,
) -> Result<Tensor> {
    candle_core::bail!("gated_delta_rule_recurrence_cuda requires the cuda feature")
}
|
||||
|
||||
/// CUDA-accelerated chunked gated delta rule recurrence (prefill optimization).
///
/// Processes prefill tokens in 64-token chunks instead of one at a time.
/// Same interface as `gated_delta_rule_recurrence_cuda`.
///
/// Inputs (all contiguous, f32):
///   q, k: [BH, S, K]    v: [BH, S, V]    g, beta: [BH, S]
///   state: [BH, K, V]   (mutated in place)
///
/// Returns: output [BH, S, V]
///
/// # Errors
///
/// Fails when `q` is not rank 3, when any tensor is not an f32 CUDA tensor,
/// when a tensor holds fewer elements than the shapes above imply (the kernel
/// would otherwise read or write past the end of its buffer), or when a
/// dimension does not fit in the `i32` the kernel takes.
///
/// Contiguity is a documented precondition and is NOT checked here: the
/// kernel receives raw base pointers and ignores strides.
#[cfg(feature = "cuda")]
pub fn chunked_gated_delta_rule_recurrence_cuda(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    g: &Tensor,
    beta: &Tensor,
    state: &mut Tensor,
) -> Result<Tensor> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;

    /// Rejects a tensor that holds fewer elements than the kernel will touch.
    fn require_elems(name: &str, tensor: &Tensor, needed: usize) -> Result<()> {
        let have = tensor.elem_count();
        if have < needed {
            candle_core::bail!(
                "chunked_gated_delta_rule_recurrence_cuda: {} holds {} elements, kernel needs {}",
                name,
                have,
                needed
            );
        }
        Ok(())
    }

    /// Converts a dimension to the `i32` the kernel expects, without truncating.
    fn kernel_dim(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(dim) => Ok(dim),
            Err(_) => candle_core::bail!(
                "chunked_gated_delta_rule_recurrence_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    // Borrows the f32 CUDA slice behind a storage guard, or bails with `$msg`.
    macro_rules! cuda_f32 {
        ($guard:expr, $msg:tt) => {
            match &*$guard {
                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
                _ => candle::bail!($msg),
            }
        };
    }

    let (bh, seq_len, k_dim) = q.dims3()?;
    let v_dim = v.dim(2)?;

    let dev = q.device().as_cuda_device()?;

    // Validate everything the kernel will index before handing out pointers.
    require_elems("k", k, bh * seq_len * k_dim)?;
    require_elems("v", v, bh * seq_len * v_dim)?;
    require_elems("g", g, bh * seq_len)?;
    require_elems("beta", beta, bh * seq_len)?;
    require_elems("state", state, bh * k_dim * v_dim)?;

    let bh_arg = kernel_dim("bh", bh)?;
    let seq_len_arg = kernel_dim("seq_len", seq_len)?;
    let k_dim_arg = kernel_dim("k_dim", k_dim)?;
    let v_dim_arg = kernel_dim("v_dim", v_dim)?;

    // Each guard must stay alive for as long as the slice borrowed from it.
    let (q_guard, q_layout) = q.storage_and_layout();
    let q_s = cuda_f32!(q_guard, "q must be a cuda tensor");
    let q_offset = q_layout.start_offset();

    let (k_guard, k_layout) = k.storage_and_layout();
    let k_s = cuda_f32!(k_guard, "k must be a cuda tensor");
    let k_offset = k_layout.start_offset();

    let (v_guard, v_layout) = v.storage_and_layout();
    let v_s = cuda_f32!(v_guard, "v must be a cuda tensor");
    let v_offset = v_layout.start_offset();

    let (g_guard, g_layout) = g.storage_and_layout();
    let g_s = cuda_f32!(g_guard, "g must be a cuda tensor");
    let g_offset = g_layout.start_offset();

    let (beta_guard, beta_layout) = beta.storage_and_layout();
    let beta_s = cuda_f32!(beta_guard, "beta must be a cuda tensor");
    let beta_offset = beta_layout.start_offset();

    let (state_guard, state_layout) = state.storage_and_layout();
    let state_s = cuda_f32!(state_guard, "state must be a cuda tensor");
    let state_offset = state_layout.start_offset();

    // SAFETY: the buffer is uninitialized on return from `alloc`; it is only
    // exposed to callers after the kernel below has been enqueued on the same
    // device stream. NOTE(review): this relies on the kernel writing every
    // output element — confirm against the kernel source.
    let output_buf = unsafe { dev.alloc::<f32>(bh * seq_len * v_dim) }?;

    let stream = dev.cuda_stream().cu_stream() as i64;

    // SAFETY: every pointer comes from a live CUDA slice whose guard is still
    // in scope, each input was checked above to hold at least as many
    // elements as its documented shape, and `output_buf` was sized for
    // [BH, S, V]. The kernel writes `state` through the raw pointer.
    unsafe {
        crate::cuda::ffi::chunked_gated_delta_rule_recurrence(
            q_s.slice(q_offset..).device_ptr(q_s.stream()).0 as *const f32,
            k_s.slice(k_offset..).device_ptr(k_s.stream()).0 as *const f32,
            v_s.slice(v_offset..).device_ptr(v_s.stream()).0 as *const f32,
            g_s.slice(g_offset..).device_ptr(g_s.stream()).0 as *const f32,
            beta_s.slice(beta_offset..).device_ptr(beta_s.stream()).0 as *const f32,
            state_s.slice(state_offset..).device_ptr(state_s.stream()).0 as *mut f32,
            output_buf.device_ptr(output_buf.stream()).0 as *mut f32,
            bh_arg,
            seq_len_arg,
            k_dim_arg,
            v_dim_arg,
            stream,
        );
    }

    let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
    Ok(Tensor::from((
        candle::Storage::Cuda(output_storage),
        (bh, seq_len, v_dim),
    )))
}
|
||||
|
||||
/// Stand-in for builds without the `cuda` feature.
///
/// Keeps the symbol and signature available so call sites compile on
/// CPU-only builds; every call returns an error and no argument is read.
#[cfg(not(feature = "cuda"))]
// Silences unused warnings on the non-cuda build.
#[allow(unused)]
pub fn chunked_gated_delta_rule_recurrence_cuda(
    _q: &Tensor,
    _k: &Tensor,
    _v: &Tensor,
    _g: &Tensor,
    _beta: &Tensor,
    _state: &mut Tensor,
) -> Result<Tensor> {
    candle_core::bail!("chunked_gated_delta_rule_recurrence_cuda requires the cuda feature")
}
|
||||
|
||||
/// CUDA-accelerated causal conv1d (both update and full paths).
///
/// For update (is_update=true):
///   x: [B, conv_dim, 1]    weight: [conv_dim, kernel_size]
///   conv_state: [B, conv_dim, kernel_size]  (mutated in place for update)
///   Returns: (output [B, conv_dim, 1], updated conv_state)
///
/// For full (is_update=false):
///   x: [B, conv_dim, S]    weight: [conv_dim, kernel_size]
///   Returns: (output [B, conv_dim, S], new conv_state [B, conv_dim, kernel_size])
///
/// `conv_state` is not read on the full path.
///
/// # Errors
///
/// Fails when `x` is not a rank-3 f16/bf16 CUDA tensor, when `weight` (or
/// `conv_state` on the update path) has a different dtype or is not on CUDA,
/// when `kernel_size` is zero, when the update path is given a sequence
/// length other than 1, when `weight` / `conv_state` hold fewer elements than
/// the shapes above imply, or when a dimension does not fit in `i32`.
#[cfg(feature = "cuda")]
pub fn causal_conv1d_cuda(
    x: &Tensor,
    weight: &Tensor,
    conv_state: &Tensor,
    kernel_size: usize,
    is_update: bool,
) -> Result<(Tensor, Tensor)> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;
    use core::ffi::c_void;

    /// Rejects a tensor that holds fewer elements than the kernel will touch.
    fn require_elems(name: &str, tensor: &Tensor, needed: usize) -> Result<()> {
        let have = tensor.elem_count();
        if have < needed {
            candle_core::bail!(
                "causal_conv1d_cuda: {} holds {} elements, kernel needs {}",
                name,
                have,
                needed
            );
        }
        Ok(())
    }

    /// Converts a dimension to the `i32` the kernel expects, without truncating.
    fn kernel_dim(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(dim) => Ok(dim),
            Err(_) => candle_core::bail!(
                "causal_conv1d_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    /// Runs either path for element type `T`; `dtype_code` is the matching
    /// selector understood by the C launchers (0 = f16, 1 = bf16).
    fn cuda_fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        x: &Tensor,
        weight: &Tensor,
        conv_state: &Tensor,
        kernel_size: usize,
        is_update: bool,
        dtype_code: i32,
    ) -> Result<(Tensor, Tensor)> {
        let dev = x.device().as_cuda_device()?;
        let (batch_size, conv_dim, seq_len) = x.dims3()?;

        // The update kernel stores to conv_state[kernel_size - 1]; a zero
        // window would turn that into a write before the buffer.
        if kernel_size == 0 {
            candle::bail!("causal_conv1d_cuda: kernel_size must be at least 1");
        }
        require_elems("weight", weight, conv_dim * kernel_size)?;

        let batch_arg = kernel_dim("batch_size", batch_size)?;
        let conv_dim_arg = kernel_dim("conv_dim", conv_dim)?;
        let seq_len_arg = kernel_dim("seq_len", seq_len)?;
        let kernel_size_arg = kernel_dim("kernel_size", kernel_size)?;

        let (x_s, x_l) = x.storage_and_layout();
        let x_s = match &*x_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("x must be a cuda tensor"),
        };
        let x_offset = x_l.start_offset();

        let (w_s, w_l) = weight.storage_and_layout();
        let w_s = match &*w_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("weight must be a cuda tensor"),
        };
        let w_offset = w_l.start_offset();

        let stream = dev.cuda_stream().cu_stream() as i64;

        if is_update {
            // The update kernel reads exactly one sample per (batch, channel)
            // and indexes x as [B, conv_dim]; any other length would make it
            // read the wrong elements.
            if seq_len != 1 {
                candle::bail!(
                    "causal_conv1d_cuda: update path expects x of shape [B, conv_dim, 1], got sequence length {}",
                    seq_len
                );
            }
            require_elems("conv_state", conv_state, batch_size * conv_dim * kernel_size)?;

            // Second handle to the state tensor, returned to the caller.
            // NOTE(review): `Tensor::clone` is presumably a cheap handle copy
            // that shares storage, so the kernel's in-place writes are also
            // visible through the caller's `conv_state` — confirm against the
            // candle version in use.
            let conv_state_new = conv_state.clone();

            // SAFETY: uninitialized until the kernel below has run; the
            // kernel writes one element per (batch, channel).
            let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim) }?;

            // Scope the borrow of conv_state_new so we can move it later
            {
                let (cs_s, cs_l) = conv_state_new.storage_and_layout();
                let cs_s = match &*cs_s {
                    candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
                    _ => candle::bail!("conv_state must be a cuda tensor"),
                };
                let cs_offset = cs_l.start_offset();

                // SAFETY: all pointers come from live CUDA slices of element
                // type `T` (matching `dtype_code`); weight and conv_state
                // were size-checked above and output_buf was sized for
                // [B, conv_dim, 1].
                unsafe {
                    crate::cuda::ffi::causal_conv1d_update(
                        x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
                        w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
                        cs_s.slice(cs_offset..).device_ptr(cs_s.stream()).0 as *mut c_void,
                        output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
                        batch_arg,
                        conv_dim_arg,
                        kernel_size_arg,
                        dtype_code,
                        stream,
                    );
                }
            }

            let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
            let output = Tensor::from((
                candle::Storage::Cuda(output_storage),
                (batch_size, conv_dim, 1usize),
            ));

            Ok((output, conv_state_new))
        } else {
            // Full path: allocate new conv_state and output
            // SAFETY: both buffers are uninitialized until the launcher's two
            // kernels have run on the device stream.
            let output_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * seq_len) }?;
            let cs_buf = unsafe { dev.alloc::<T>(batch_size * conv_dim * kernel_size) }?;

            // SAFETY: x and weight pointers come from live CUDA slices of
            // element type `T`; weight was size-checked above; the two output
            // buffers were sized for [B, conv_dim, S] and
            // [B, conv_dim, kernel_size].
            unsafe {
                crate::cuda::ffi::causal_conv1d_full(
                    x_s.slice(x_offset..).device_ptr(x_s.stream()).0 as *const c_void,
                    w_s.slice(w_offset..).device_ptr(w_s.stream()).0 as *const c_void,
                    cs_buf.device_ptr(cs_buf.stream()).0 as *mut c_void,
                    output_buf.device_ptr(output_buf.stream()).0 as *mut c_void,
                    batch_arg,
                    conv_dim_arg,
                    seq_len_arg,
                    kernel_size_arg,
                    dtype_code,
                    stream,
                );
            }

            let output_storage = candle::CudaStorage::wrap_cuda_slice(output_buf, dev.clone());
            let output = Tensor::from((
                candle::Storage::Cuda(output_storage),
                (batch_size, conv_dim, seq_len),
            ));

            let cs_storage = candle::CudaStorage::wrap_cuda_slice(cs_buf, dev.clone());
            let new_conv_state = Tensor::from((
                candle::Storage::Cuda(cs_storage),
                (batch_size, conv_dim, kernel_size),
            ));

            Ok((output, new_conv_state))
        }
    }

    match x.dtype() {
        DType::F16 => cuda_fwd::<half::f16>(x, weight, conv_state, kernel_size, is_update, 0),
        DType::BF16 => cuda_fwd::<half::bf16>(x, weight, conv_state, kernel_size, is_update, 1),
        other => candle_core::bail!("causal_conv1d_cuda only supports f16/bf16, got {:?}", other),
    }
}
|
||||
|
||||
/// Stand-in for builds without the `cuda` feature.
///
/// Keeps the symbol and signature available so call sites compile on
/// CPU-only builds; every call returns an error and no argument is read.
#[cfg(not(feature = "cuda"))]
// Silences unused warnings on the non-cuda build.
#[allow(unused)]
pub fn causal_conv1d_cuda(
    _x: &Tensor,
    _weight: &Tensor,
    _conv_state: &Tensor,
    _kernel_size: usize,
    _is_update: bool,
) -> Result<(Tensor, Tensor)> {
    candle_core::bail!("causal_conv1d_cuda requires the cuda feature")
}
|
||||
|
||||
/// CUDA-accelerated fused GDN gating computation.
///
/// Computes: beta = sigmoid(b), g = -exp(a_log) * softplus(a + dt_bias)
///
///   b, a: [total_elements] in f16/bf16
///   a_log, dt_bias: [num_heads] in f32
///
/// Returns: (beta, g) in original dtype, both with the shape of `b`.
///
/// `num_heads` is taken from `a_log`; the kernel picks the head of element
/// `i` as `i % num_heads`, so the innermost axis of `b` / `a` is expected to
/// be the head axis.
///
/// # Errors
///
/// Fails when `b` is not an f16/bf16 CUDA tensor, when `a` has a different
/// dtype, when `a_log` / `dt_bias` are not f32 CUDA tensors, when `a_log` is
/// empty, when `a` or `dt_bias` hold fewer elements than `b` / `a_log`
/// respectively, or when a count does not fit in `i32`.
#[cfg(feature = "cuda")]
pub fn fused_gdn_gating_cuda(
    b: &Tensor,
    a: &Tensor,
    a_log: &Tensor,
    dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;
    use core::ffi::c_void;

    /// Rejects a tensor that holds fewer elements than the kernel will touch.
    fn require_elems(name: &str, tensor: &Tensor, needed: usize) -> Result<()> {
        let have = tensor.elem_count();
        if have < needed {
            candle_core::bail!(
                "fused_gdn_gating_cuda: {} holds {} elements, kernel needs {}",
                name,
                have,
                needed
            );
        }
        Ok(())
    }

    /// Converts a count to the `i32` the kernel expects, without truncating.
    fn kernel_dim(name: &str, value: usize) -> Result<i32> {
        match i32::try_from(value) {
            Ok(dim) => Ok(dim),
            Err(_) => candle_core::bail!(
                "fused_gdn_gating_cuda: {} = {} does not fit in i32",
                name,
                value
            ),
        }
    }

    /// Runs the kernel for element type `T`; `dtype_code` is the matching
    /// selector understood by the C launcher (0 = f16, 1 = bf16).
    fn cuda_fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        b: &Tensor,
        a: &Tensor,
        a_log: &Tensor,
        dt_bias: &Tensor,
        dtype_code: i32,
    ) -> Result<(Tensor, Tensor)> {
        let total_elements = b.elem_count();
        let num_heads = a_log.elem_count();
        let shape = b.shape().clone();
        let dev = b.device().as_cuda_device()?;

        // The kernel computes `idx % num_heads`; zero heads would be a
        // division by zero on the device.
        if num_heads == 0 {
            candle::bail!("fused_gdn_gating_cuda: a_log must contain at least one head");
        }
        require_elems("a", a, total_elements)?;
        require_elems("dt_bias", dt_bias, num_heads)?;

        let total_arg = kernel_dim("total_elements", total_elements)?;
        let heads_arg = kernel_dim("num_heads", num_heads)?;

        let (b_s, b_l) = b.storage_and_layout();
        let b_s = match &*b_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("b must be a cuda tensor"),
        };
        let b_offset = b_l.start_offset();

        let (a_s, a_l) = a.storage_and_layout();
        let a_s = match &*a_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("a must be a cuda tensor"),
        };
        let a_offset = a_l.start_offset();

        let (alog_s, alog_l) = a_log.storage_and_layout();
        let alog_s = match &*alog_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
            _ => candle::bail!("a_log must be a cuda tensor"),
        };
        let alog_offset = alog_l.start_offset();

        let (dtb_s, dtb_l) = dt_bias.storage_and_layout();
        let dtb_s = match &*dtb_s {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
            _ => candle::bail!("dt_bias must be a cuda tensor"),
        };
        let dtb_offset = dtb_l.start_offset();

        // SAFETY: both buffers are uninitialized until the kernel below has
        // run; the kernel writes one element of each per input element.
        let beta_buf = unsafe { dev.alloc::<T>(total_elements) }?;
        let g_buf = unsafe { dev.alloc::<T>(total_elements) }?;

        let stream = dev.cuda_stream().cu_stream() as i64;

        // SAFETY: every pointer comes from a live CUDA slice of the element
        // type the kernel is told to expect; `a` and `dt_bias` were
        // size-checked above, and both outputs were sized to
        // `total_elements`.
        unsafe {
            crate::cuda::ffi::fused_gdn_gating(
                b_s.slice(b_offset..).device_ptr(b_s.stream()).0 as *const c_void,
                a_s.slice(a_offset..).device_ptr(a_s.stream()).0 as *const c_void,
                alog_s.slice(alog_offset..).device_ptr(alog_s.stream()).0 as *const f32,
                dtb_s.slice(dtb_offset..).device_ptr(dtb_s.stream()).0 as *const f32,
                beta_buf.device_ptr(beta_buf.stream()).0 as *mut c_void,
                g_buf.device_ptr(g_buf.stream()).0 as *mut c_void,
                total_arg,
                heads_arg,
                dtype_code,
                stream,
            );
        }

        let beta_storage = candle::CudaStorage::wrap_cuda_slice(beta_buf, dev.clone());
        let beta = Tensor::from((candle::Storage::Cuda(beta_storage), shape.clone()));

        let g_storage = candle::CudaStorage::wrap_cuda_slice(g_buf, dev.clone());
        let g = Tensor::from((candle::Storage::Cuda(g_storage), shape));

        Ok((beta, g))
    }

    match b.dtype() {
        DType::F16 => cuda_fwd::<half::f16>(b, a, a_log, dt_bias, 0),
        DType::BF16 => cuda_fwd::<half::bf16>(b, a, a_log, dt_bias, 1),
        other => candle_core::bail!(
            "fused_gdn_gating_cuda only supports f16/bf16, got {:?}",
            other
        ),
    }
}
|
||||
|
||||
/// Stand-in for builds without the `cuda` feature.
///
/// Keeps the symbol and signature available so call sites compile on
/// CPU-only builds; every call returns an error and no argument is read.
#[cfg(not(feature = "cuda"))]
// Silences unused warnings on the non-cuda build.
#[allow(unused)]
pub fn fused_gdn_gating_cuda(
    _b: &Tensor,
    _a: &Tensor,
    _a_log: &Tensor,
    _dt_bias: &Tensor,
) -> Result<(Tensor, Tensor)> {
    candle_core::bail!("fused_gdn_gating_cuda requires the cuda feature")
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//! CUDA kernels and their Rust wrappers.
|
||||
//!
|
||||
//! Currently scoped to what we need for Qwen3-Next (`qwen3_5`)
|
||||
//! inference performance — the Gated DeltaNet kernels ported from
|
||||
//! `EricLBuehler/mistral.rs` (MIT). Each kernel lives in a `.cu`
|
||||
//! file alongside this module; `build.rs` compiles them all into a
|
||||
//! static lib via `cudaforge` and links it under the `cuda` feature.
|
||||
//!
|
||||
//! When we absorb more upstream kernels (MoE GEMM, top-k, Mamba SSM,
|
||||
//! etc.) they land here in their own `.cu` + `.rs` pairs.
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod ffi;
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod gdn;
|
||||
@@ -100,87 +100,6 @@ pub fn parse_health_info(csv_output: &str) -> Result<Vec<DeviceHealth>> {
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
// ── Driver/library mismatch preflight (#19) ─────────────────────────
|
||||
|
||||
/// Decide whether a failed nvidia-smi run is the "Driver/library version
/// mismatch" condition (userspace libraries updated while the old kernel
/// module is still loaded, which breaks every CUDA call until a reboot).
///
/// Returns:
/// - `None` when the output does not mention the mismatch at all. Other
///   failures (no devices, permissions, ...) must not be reported as a
///   mismatch.
/// - `Some(version)` when the first line of the form
///   `NVML library version: <version>` carries a non-empty version.
/// - `Some("unknown")` when the mismatch is present but that first line is
///   missing or has nothing after the colon.
pub fn classify_driver_mismatch(combined_output: &str) -> Option<String> {
    const MISMATCH_MARKER: &str = "Driver/library version mismatch";
    const VERSION_PREFIX: &str = "NVML library version:";

    if !combined_output.contains(MISMATCH_MARKER) {
        return None;
    }

    // Only the FIRST line carrying the prefix is considered, even if its
    // version part turns out to be blank.
    let mut version = "";
    for line in combined_output.lines() {
        if let Some(rest) = line.trim().strip_prefix(VERSION_PREFIX) {
            version = rest.trim();
            break;
        }
    }

    let reported = if version.is_empty() { "unknown" } else { version };
    Some(reported.to_string())
}
|
||||
|
||||
/// Pull the loaded kernel module's driver version out of the contents of
/// `/proc/driver/nvidia/version`. Typical first line:
///
/// ```text
/// NVRM version: NVIDIA UNIX Open Kernel Module for x86_64  580.159.03  Release Build  (...)
/// ```
///
/// Only the first line that starts (without leading whitespace) with
/// `NVRM version:` is inspected. The result is the first whitespace-separated
/// token on that line whose first two dot-separated components are non-empty
/// runs of ASCII digits; anything after the second component is not checked.
/// Returns `None` when there is no such line or no such token.
pub fn parse_kernel_module_version(proc_contents: &str) -> Option<String> {
    /// True when `part` exists, is non-empty and consists of ASCII digits only.
    fn is_number(part: Option<&str>) -> bool {
        matches!(part, Some(p) if !p.is_empty() && p.bytes().all(|b| b.is_ascii_digit()))
    }

    let header = proc_contents
        .lines()
        .find(|candidate| candidate.starts_with("NVRM version:"))?;

    for token in header.split_whitespace() {
        let mut components = token.split('.');
        let major_ok = is_number(components.next());
        if major_ok && is_number(components.next()) {
            return Some(token.to_owned());
        }
    }
    None
}
|
||||
|
||||
/// Build the operator-facing description of a driver/library mismatch —
/// the text carried in `DiscoveryResponse::cuda_unavailable_reason` and
/// logged at startup.
///
/// `userspace` is the userspace NVML version string; `kernel_module` is the
/// loaded kernel module's version, rendered as `unknown` when `None`.
pub fn mismatch_reason(userspace: &str, kernel_module: Option<&str>) -> String {
    let loaded = match kernel_module {
        Some(version) => version,
        None => "unknown",
    };
    format!(
        "host NVIDIA driver/library mismatch (userspace NVML {userspace} vs loaded kernel \
         module {loaded}) — reboot the host to reload the kernel module; all CUDA inference is \
         unavailable until then"
    )
}
|
||||
|
||||
/// Outcome of an nvidia-smi invocation, distinguishing "binary not
/// present" (CPU-only host, not an error) from "present but failing"
/// (possible driver mismatch — worth classifying).
enum SmiOutcome {
    /// The command exited successfully; carries its stdout as text.
    Ok(String),
    /// The command did not succeed; carries the diagnostic text that is
    /// later fed to `classify_driver_mismatch`.
    Failed(String),
    /// nvidia-smi is treated as not installed on this host; discovery
    /// reports no GPU devices and no failure reason.
    Absent,
}
|
||||
|
||||
async fn run_nvidia_smi(args: &[&str]) -> SmiOutcome {
|
||||
match tokio::process::Command::new("nvidia-smi")
|
||||
.args(args)
|
||||
.output()
|
||||
.await
|
||||
{
|
||||
Err(_) => SmiOutcome::Absent,
|
||||
Ok(out) if out.status.success() => {
|
||||
SmiOutcome::Ok(String::from_utf8_lossy(&out.stdout).to_string())
|
||||
}
|
||||
Ok(out) => {
|
||||
let mut combined = String::from_utf8_lossy(&out.stdout).to_string();
|
||||
combined.push('\n');
|
||||
combined.push_str(&String::from_utf8_lossy(&out.stderr));
|
||||
SmiOutcome::Failed(combined)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Command execution wrappers ──────────────────────────────────────
|
||||
|
||||
async fn run_command(cmd: &str, args: &[&str]) -> Result<String> {
|
||||
@@ -220,42 +139,23 @@ pub async fn discover_system() -> Result<DiscoveryResponse> {
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
let (devices, driver_version, cuda_unavailable_reason) = match run_nvidia_smi(&[
|
||||
&format!("--query-gpu={NVIDIA_SMI_DISCOVERY_QUERY}"),
|
||||
"--format=csv,noheader,nounits",
|
||||
])
|
||||
let (devices, driver_version) = match run_command_optional(
|
||||
"nvidia-smi",
|
||||
&[
|
||||
&format!("--query-gpu={NVIDIA_SMI_DISCOVERY_QUERY}"),
|
||||
"--format=csv,noheader,nounits",
|
||||
],
|
||||
)
|
||||
.await
|
||||
{
|
||||
SmiOutcome::Ok(output) => {
|
||||
Some(output) => {
|
||||
let devs = parse_gpu_info(&output).unwrap_or_default();
|
||||
let driver = parse_driver_version(&output);
|
||||
(devs, driver, None)
|
||||
(devs, driver)
|
||||
}
|
||||
SmiOutcome::Absent => {
|
||||
None => {
|
||||
tracing::info!("nvidia-smi not found — no GPU devices discovered");
|
||||
(vec![], None, None)
|
||||
}
|
||||
SmiOutcome::Failed(combined) => {
|
||||
// nvidia-smi exists but can't talk to the driver. The case
|
||||
// worth diagnosing precisely is the userspace↔kernel-module
|
||||
// version skew after an un-rebooted driver update (#19) —
|
||||
// every CUDA call on the host fails until a reboot, and
|
||||
// without this classification it surfaces as a cryptic
|
||||
// NCCL/cuInit error deep inside the first model load.
|
||||
let reason = classify_driver_mismatch(&combined).map(|userspace| {
|
||||
let kmod = std::fs::read_to_string("/proc/driver/nvidia/version")
|
||||
.ok()
|
||||
.as_deref()
|
||||
.and_then(parse_kernel_module_version);
|
||||
mismatch_reason(&userspace, kmod.as_deref())
|
||||
});
|
||||
if reason.is_none() {
|
||||
tracing::warn!(
|
||||
output = %combined.trim(),
|
||||
"nvidia-smi present but failing — no GPU devices discovered"
|
||||
);
|
||||
}
|
||||
(vec![], None, reason)
|
||||
(vec![], None)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -272,7 +172,6 @@ pub async fn discover_system() -> Result<DiscoveryResponse> {
|
||||
driver_version,
|
||||
devices,
|
||||
harnesses: vec![], // populated by harness registry in Phase 8
|
||||
cuda_unavailable_reason,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -373,63 +272,4 @@ mod tests {
|
||||
assert_eq!(health[1].vram_used_mb, 4096);
|
||||
assert_eq!(health[1].temp_c, 58);
|
||||
}
|
||||
|
||||
// ── #19 driver/library mismatch preflight ────────────────────────
|
||||
|
||||
#[test]
fn classify_driver_mismatch_detects_and_extracts_nvml_version() {
    // Shape of nvidia-smi's failure output on a host whose userspace
    // libraries were updated without a reboot: the mismatch banner
    // followed by the NVML library version line.
    let smi_output = concat!(
        "Failed to initialize NVML: Driver/library version mismatch\n",
        "NVML library version: 580.159\n",
    );
    let classified = classify_driver_mismatch(smi_output);
    assert_eq!(classified.as_deref(), Some("580.159"));
}
|
||||
|
||||
#[test]
fn classify_driver_mismatch_without_version_line() {
    // The banner alone is still a mismatch; the version falls back to
    // the "unknown" placeholder.
    let banner_only = "Failed to initialize NVML: Driver/library version mismatch\n";
    let classified = classify_driver_mismatch(banner_only);
    assert_eq!(classified.as_deref(), Some("unknown"));
}
|
||||
|
||||
#[test]
fn classify_driver_mismatch_ignores_other_failures() {
    // nvidia-smi failures that are not the version mismatch must never be
    // diagnosed as one (no false positives on healthy or odd hosts).
    const UNRELATED_FAILURES: [&str; 4] = [
        "No devices were found\n",
        "Failed to initialize NVML: Insufficient Permissions\n",
        "NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver.\n",
        "",
    ];
    for out in UNRELATED_FAILURES {
        let verdict = classify_driver_mismatch(out);
        assert_eq!(verdict, None, "false positive on: {out:?}");
    }
}
|
||||
|
||||
#[test]
fn parse_kernel_module_version_from_proc() {
    // Two-line sample: the NVRM line carries the driver version, the GCC
    // line must not be mistaken for it.
    let proc = concat!(
        "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64 580.159.03 Release Build (dvs-builder@U22-I3-AE24-12-2) Tue May 12 21:03:35 UTC 2026\n",
        "GCC version: gcc version 15.2.1 20251022 (Red Hat 15.2.1-3) (GCC)\n",
    );
    let parsed = parse_kernel_module_version(proc);
    assert_eq!(parsed.as_deref(), Some("580.159.03"));
}
|
||||
|
||||
#[test]
fn parse_kernel_module_version_absent() {
    // Neither empty contents nor contents without an NVRM line yield a
    // version.
    let empty = parse_kernel_module_version("");
    let gcc_only = parse_kernel_module_version("GCC version: gcc 15\n");
    assert_eq!(empty, None);
    assert_eq!(gcc_only, None);
}
|
||||
|
||||
#[test]
fn mismatch_reason_is_operator_actionable() {
    // The rendered reason must name both versions and tell the operator
    // what to do about it.
    let reason = mismatch_reason("580.159", Some("580.159.03"));
    for needle in ["580.159", "580.159.03", "reboot"] {
        assert!(reason.contains(needle));
    }
}
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
//! Custom architecture implementations.
|
||||
//!
|
||||
//! When candle-transformers ships a model family unchanged
|
||||
//! (`models::llama`, `models::qwen3`, `models::qwen3_moe`, etc.), the
|
||||
//! handler in `harness/candle.rs` just wraps the upstream type in a
|
||||
//! `ModelArch` variant.
|
||||
//!
|
||||
//! When candle has nothing for the architecture and we have to write
|
||||
//! it from scratch — Qwen3-Next / Qwen3.6 (`qwen3_5`) being the
|
||||
//! motivating example — the implementation lands here, one file per
|
||||
//! architecture.
|
||||
//!
|
||||
//! Each architecture module is expected to expose:
|
||||
//! - A `Config` type deserialised from the model's `config.json`
|
||||
//! (some architectures nest the real hyperparams under `text_config`,
|
||||
//! in which case the module owns the unwrapping).
|
||||
//! - A `ForCausalLM` struct with `new`, `forward(&mut self, x, offset)
|
||||
//! -> Result<Tensor>`, and `clear_kv_cache(&mut self)`.
|
||||
//!
|
||||
//! TP-aware analogues live in `harness/tp/tp_<family>.rs` and follow
|
||||
//! the pattern set by `tp_qwen3.rs`.
|
||||
|
||||
pub mod qwen3_5;
|
||||
@@ -1,152 +0,0 @@
|
||||
//! Qwen3-Next decoder layer.
|
||||
//!
|
||||
//! Standard pre-norm transformer block (LN → attention → residual →
|
||||
//! LN → MLP → residual) where the attention slot dispatches on the
|
||||
//! per-layer `layer_types[i]` value in the config:
|
||||
//!
|
||||
//! - `"full_attention"` → [`Qwen3_5Attention`] (GQA causal + output
|
||||
//! gate + RoPE + KV cache).
|
||||
//! - `"linear_attention"` → [`GatedDeltaNet`] (recurrent delta rule +
|
||||
//! causal conv + per-head state).
|
||||
//!
|
||||
//! In Qwen3.6-27B every 4th layer is full_attention; the rest are
|
||||
//! linear_attention. `full_attention_interval` in the config is a
|
||||
//! hint; `layer_types` is authoritative.
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::TextConfig;
|
||||
use super::full_attn::Qwen3_5Attention;
|
||||
use super::linear_attn::GatedDeltaNet;
|
||||
use super::mlp::Qwen3_5MLP;
|
||||
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||
use super::rope::RotaryEmbedding;
|
||||
use super::snapshot::LayerKvSnapshot;
|
||||
|
||||
/// One of the two attention flavours sitting in a decoder layer's
/// attention slot. Full-attention layers need the rotary table and
/// take an attention mask; linear-attention layers carry their own
/// recurrent state and ignore the mask.
enum AttentionKind {
    /// Selected when the layer's `layer_types` entry is `"full_attention"`.
    Full(Qwen3_5Attention),
    /// Selected when the layer's `layer_types` entry is `"linear_attention"`.
    Linear(GatedDeltaNet),
}
|
||||
|
||||
/// One pre-norm decoder block: norm → attention → residual add, then
/// norm → MLP → residual add. The attention slot holds either a
/// full-attention or a linear-attention module, chosen at load time.
pub struct Qwen3_5DecoderLayer {
    /// Norm applied to the layer input before the attention slot.
    input_layernorm: Qwen3_5RmsNorm,
    /// Norm applied to the post-attention residual stream before the MLP.
    post_attention_layernorm: Qwen3_5RmsNorm,
    /// Feed-forward block applied after the second norm.
    mlp: Qwen3_5MLP,
    /// The attention flavour this layer was loaded with.
    attention: AttentionKind,
}
|
||||
|
||||
impl Qwen3_5DecoderLayer {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
layer_idx: usize,
|
||||
vb: &ShardedVarBuilder,
|
||||
) -> Result<Self> {
|
||||
let layer_type = cfg
|
||||
.layer_types
|
||||
.get(layer_idx)
|
||||
.map(String::as_str)
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"layer_types[{layer_idx}] missing (have {} entries)",
|
||||
cfg.layer_types.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
let attention = match layer_type {
|
||||
"full_attention" => {
|
||||
AttentionKind::Full(Qwen3_5Attention::load(cfg, rotary, &vb.pp("self_attn"))?)
|
||||
}
|
||||
"linear_attention" => {
|
||||
AttentionKind::Linear(GatedDeltaNet::load(cfg, &vb.pp("linear_attn"))?)
|
||||
}
|
||||
other => anyhow::bail!(
|
||||
"unknown layer_type '{other}' for layer {layer_idx} (expected \
|
||||
'full_attention' or 'linear_attention')"
|
||||
),
|
||||
};
|
||||
|
||||
let mlp = Qwen3_5MLP::load(cfg, &vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
Qwen3_5RmsNorm::load(&vb.pp("input_layernorm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
let post_attention_layernorm = Qwen3_5RmsNorm::load(
|
||||
&vb.pp("post_attention_layernorm"),
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
attention,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let h = self.input_layernorm.forward(x)?;
|
||||
let attn_out = match &mut self.attention {
|
||||
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
|
||||
// Linear attention ignores attn_mask + rope; its causal
|
||||
// structure is baked into the recurrent state lifecycle.
|
||||
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||
};
|
||||
let x = (x + attn_out)?;
|
||||
let h2 = self.post_attention_layernorm.forward(&x)?;
|
||||
let h2 = self.mlp.forward(&h2)?;
|
||||
x + h2
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
match &mut self.attention {
|
||||
AttentionKind::Full(attn) => attn.clear_kv_cache(),
|
||||
AttentionKind::Linear(net) => net.clear_kv_cache(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Capture this layer's cache state for a prefix snapshot.
|
||||
pub fn snapshot_kv(&self) -> candle_core::Result<LayerKvSnapshot> {
|
||||
Ok(match &self.attention {
|
||||
AttentionKind::Full(attn) => LayerKvSnapshot::Full(attn.snapshot_kv()),
|
||||
AttentionKind::Linear(net) => {
|
||||
let (conv_state, recurrent_state) = net.snapshot_state()?;
|
||||
LayerKvSnapshot::Linear {
|
||||
conv_state,
|
||||
recurrent_state,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Replace this layer's cache state from a snapshot. The snapshot
|
||||
/// variant must match the layer's attention kind — a mismatch
|
||||
/// means the snapshot came from a different model.
|
||||
pub fn restore_kv(&mut self, snap: &LayerKvSnapshot) -> candle_core::Result<()> {
|
||||
match (&mut self.attention, snap) {
|
||||
(AttentionKind::Full(attn), LayerKvSnapshot::Full(kv)) => attn.restore_kv(kv.as_ref()),
|
||||
(
|
||||
AttentionKind::Linear(net),
|
||||
LayerKvSnapshot::Linear {
|
||||
conv_state,
|
||||
recurrent_state,
|
||||
},
|
||||
) => net.restore_state(conv_state.as_ref(), recurrent_state.as_ref()),
|
||||
_ => candle_core::bail!(
|
||||
"restore_kv: snapshot layer kind does not match this layer's attention kind"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
//! Qwen3-Next's `full_attention` layer.
|
||||
//!
|
||||
//! Standard GQA causal attention with two Qwen3-Next-specific quirks:
|
||||
//!
|
||||
//! 1. **Output gate (`attn_output_gate=True`).** `q_proj` is widened
|
||||
//! to `num_heads * head_dim * 2`. The second half is reshaped to
|
||||
//! `(B, L, num_heads * head_dim)` and fed through a sigmoid; the
|
||||
//! attention output is pointwise-multiplied by this gate before
|
||||
//! `o_proj`. Effectively a per-head per-position attenuation on
|
||||
//! the attention output.
|
||||
//!
|
||||
//! 2. **`(1 + w) * x` RmsNorm** on q and k (see `rmsnorm::Qwen3_5RmsNorm`).
|
||||
//! candle_nn's RmsNorm applies `w * x`; the upstream Qwen3-Next
|
||||
//! checkpoints expect the `(1 + w)` form.
|
||||
//!
|
||||
//! Otherwise: GQA with `num_attention_heads / num_key_value_heads`
|
||||
//! repeat, q_norm + k_norm on the head dim, GLM-style rotary (see
|
||||
//! `rope::RotaryEmbedding`), and the usual causal mask.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::kv_cache::ConcatKvCache;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_transformers::utils::repeat_kv;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::TextConfig;
|
||||
use super::rmsnorm::Qwen3_5RmsNorm;
|
||||
use super::rope::RotaryEmbedding;
|
||||
|
||||
/// Output-gated GQA attention layer with its own KV cache.
pub struct Qwen3_5Attention {
    /// Bias-free query projection, loaded 2x wide
    /// (`num_heads * head_dim * 2`): per head, the first `head_dim`
    /// outputs are the query and the next `head_dim` are the gate.
    q_proj: Linear,
    /// Bias-free key projection (`num_kv_heads * head_dim` outputs).
    k_proj: Linear,
    /// Bias-free value projection (`num_kv_heads * head_dim` outputs).
    v_proj: Linear,
    /// Bias-free output projection (`num_heads * head_dim` inputs).
    o_proj: Linear,
    /// Norm over the head dimension of the query.
    q_norm: Qwen3_5RmsNorm,
    /// Norm over the head dimension of the key.
    k_norm: Qwen3_5RmsNorm,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// `num_heads / num_kv_heads`: how many times each KV head is repeated.
    num_kv_groups: usize,
    /// Per-head dimension.
    head_dim: usize,
    /// `head_dim * num_heads` — the flattened attention width fed to
    /// `o_proj`. NOTE: this is not read from the config's `hidden_size`.
    hidden_size: usize,
    /// Shared rotary embedding used to rotate q and k.
    rotary: Arc<RotaryEmbedding>,
    /// Keys/values accumulated across forward calls (created with
    /// `ConcatKvCache::new(2)`).
    kv_cache: ConcatKvCache,
}
|
||||
|
||||
impl Qwen3_5Attention {
|
||||
pub fn load(
|
||||
cfg: &TextConfig,
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
vb: &ShardedVarBuilder,
|
||||
) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
if num_kv_heads == 0 || !num_heads.is_multiple_of(num_kv_heads) {
|
||||
anyhow::bail!(
|
||||
"num_attention_heads ({num_heads}) must be a positive multiple of \
|
||||
num_key_value_heads ({num_kv_heads})"
|
||||
);
|
||||
}
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
|
||||
// q_proj is 2x wide: the extra `num_heads * head_dim` slice is
|
||||
// the gate (see attn_output_gate notes above).
|
||||
let q_proj = load_linear_no_bias(vb, "q_proj", cfg.hidden_size, num_heads * head_dim * 2)?;
|
||||
let k_proj = load_linear_no_bias(vb, "k_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||
let v_proj = load_linear_no_bias(vb, "v_proj", cfg.hidden_size, num_kv_heads * head_dim)?;
|
||||
let o_proj = load_linear_no_bias(vb, "o_proj", num_heads * head_dim, cfg.hidden_size)?;
|
||||
|
||||
let q_norm = Qwen3_5RmsNorm::load(&vb.pp("q_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
let k_norm = Qwen3_5RmsNorm::load(&vb.pp("k_norm"), head_dim, cfg.rms_norm_eps)?;
|
||||
|
||||
let hidden_size = head_dim * num_heads;
|
||||
let kv_cache = ConcatKvCache::new(2);
|
||||
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
hidden_size,
|
||||
rotary,
|
||||
kv_cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
attn_mask: Option<&Tensor>,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l, _) = x.dims3()?;
|
||||
|
||||
// 1. q_proj — widened output, split into (query, gate).
|
||||
let q_raw = self
|
||||
.q_proj
|
||||
.forward(x)?
|
||||
.reshape((b, l, self.num_heads, self.head_dim * 2))?;
|
||||
let q = q_raw.narrow(3, 0, self.head_dim)?;
|
||||
let gate = q_raw.narrow(3, self.head_dim, self.head_dim)?;
|
||||
// Flatten the gate's head dim back into hidden_size for the
|
||||
// post-attention pointwise multiply.
|
||||
let gate = gate
|
||||
.contiguous()?
|
||||
.reshape((b, l, self.num_heads * self.head_dim))?;
|
||||
|
||||
// 2. q_norm + k_norm + reshape to (B, H, L, D).
|
||||
let q = self.q_norm.forward(&q.contiguous()?)?;
|
||||
let q = q.transpose(1, 2)?.contiguous()?; // (B, H, L, D)
|
||||
|
||||
let k = self
|
||||
.k_proj
|
||||
.forward(x)?
|
||||
.reshape((b, l, self.num_kv_heads, self.head_dim))?;
|
||||
let k = self.k_norm.forward(&k.contiguous()?)?;
|
||||
let k = k.transpose(1, 2)?.contiguous()?;
|
||||
|
||||
let v = self
|
||||
.v_proj
|
||||
.forward(x)?
|
||||
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
// 3. RoPE on q, k (cos/sin built once per forward by the model —
|
||||
// interleaved M-RoPE for image tokens, plain for text).
|
||||
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
|
||||
|
||||
// 4. KV cache.
|
||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||
|
||||
// 5. GQA repeat (cheap shape op).
|
||||
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
// 6. Scaled dot-product + causal mask.
|
||||
let scale = 1.0_f64 / (self.head_dim as f64).sqrt();
|
||||
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
if let Some(m) = attn_mask {
|
||||
scores = scores.broadcast_add(m)?;
|
||||
}
|
||||
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||
|
||||
// 7. Reshape back, apply the output gate, project.
|
||||
let ctx = ctx
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?
|
||||
.reshape((b, l, self.hidden_size))?;
|
||||
let gate_sig = candle_nn::ops::sigmoid(&gate)?;
|
||||
let gated = (ctx * gate_sig)?;
|
||||
self.o_proj.forward(&gated)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache.reset();
|
||||
}
|
||||
|
||||
/// Capture the KV cache contents for a prefix snapshot. Shallow
|
||||
/// clones: `ConcatKvCache::append` cats into fresh allocations and
|
||||
/// never mutates stored tensors in place, so the captured tensors
|
||||
/// stay valid after the live cache moves on.
|
||||
pub fn snapshot_kv(&self) -> Option<(Tensor, Tensor)> {
|
||||
match (self.kv_cache.k(), self.kv_cache.v()) {
|
||||
(Some(k), Some(v)) => Some((k.clone(), v.clone())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the live KV cache with a previously captured snapshot.
|
||||
pub fn restore_kv(&mut self, snap: Option<&(Tensor, Tensor)>) -> candle_core::Result<()> {
|
||||
self.kv_cache.reset();
|
||||
if let Some((k, v)) = snap {
|
||||
self.kv_cache.append(k, v)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn load_linear_no_bias(
|
||||
vb: &ShardedVarBuilder,
|
||||
name: &str,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.pp(name)
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,53 +0,0 @@
|
||||
//! SwiGLU MLP block for Qwen3-Next.
|
||||
//!
|
||||
//! Identical to plain Qwen3's MLP: `down(silu(gate(x)) * up(x))` with
|
||||
//! no bias on any of the three projections.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{Module, Tensor};
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
|
||||
use super::TextConfig;
|
||||
|
||||
/// Three bias-free projections forming the gated MLP
/// `down(silu(gate(x)) * up(x))`.
pub struct Qwen3_5MLP {
    /// `hidden_size → intermediate_size`; its output goes through SiLU.
    gate_proj: Linear,
    /// `hidden_size → intermediate_size`; multiplied with the activated gate.
    up_proj: Linear,
    /// `intermediate_size → hidden_size`; projects the product back down.
    down_proj: Linear,
}
|
||||
|
||||
impl Qwen3_5MLP {
|
||||
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let i = cfg.intermediate_size;
|
||||
let gate_proj = load_linear_no_bias(vb, "gate_proj", h, i)?;
|
||||
let up_proj = load_linear_no_bias(vb, "up_proj", h, i)?;
|
||||
let down_proj = load_linear_no_bias(vb, "down_proj", i, h)?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Qwen3_5MLP {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let lhs = candle_nn::ops::silu(&self.gate_proj.forward(x)?)?;
|
||||
let rhs = self.up_proj.forward(x)?;
|
||||
self.down_proj.forward(&(lhs * rhs)?)
|
||||
}
|
||||
}
|
||||
|
||||
fn load_linear_no_bias(
|
||||
vb: &ShardedVarBuilder,
|
||||
name: &str,
|
||||
in_dim: usize,
|
||||
out_dim: usize,
|
||||
) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.pp(name)
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load '{}/{name}/weight'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
@@ -1,965 +0,0 @@
|
||||
//! Qwen3-Next (`model_type = "qwen3_5"`) architecture — Qwen3.6's
|
||||
//! upstream architecture revision.
|
||||
//!
|
||||
//! ## Naming
|
||||
//!
|
||||
//! The model release this targets is `Qwen/Qwen3.6-*` but the
|
||||
//! architecture name in HuggingFace's `config.json` is `qwen3_5`.
|
||||
//! mistralrs calls the same architecture `qwen3_next`; that label
|
||||
//! ages poorly the next time Qwen ship a new arch, so we key on the
|
||||
//! canonical `qwen3_5` from the model's own config.
|
||||
//!
|
||||
//! ## Status
|
||||
//!
|
||||
//! **Single-GPU dense path is real**. Both attention flavours
|
||||
//! (`full_attention` with the output-gated GQA causal attention and
|
||||
//! `linear_attention` with the Gated DeltaNet recurrent block) are
|
||||
//! implemented. The model loads from upstream safetensors via the
|
||||
//! existing `load_arch_dense` dispatch and runs forward end to end.
|
||||
//!
|
||||
//! Numerical correctness vs the reference Python is **not yet
|
||||
//! validated** — the structural code path is right, weight tensor
|
||||
//! names match the upstream layout, shapes flow through cleanly, but
|
||||
//! the Tbilisi probe (and any other downstream test) is the next
|
||||
//! step. Likely places a bug would surface:
|
||||
//! - Per-rank vs per-token-position offsets in the recurrent delta
|
||||
//! rule (`linear_attn.rs`).
|
||||
//! - Off-by-one in the conv state continuation across decode steps.
|
||||
//! - RoPE phase mismatch from MRoPE simplification (we treat the
|
||||
//! three position grids as collapsed, which is correct only for
|
||||
//! text-only inference).
|
||||
//!
|
||||
//! ## Submodules
|
||||
//!
|
||||
//! - [`rmsnorm`] — `Qwen3_5RmsNorm` (`(1+w)*x` variant), the
|
||||
//! `Qwen3_5RmsNormGated` used after the delta rule, and the
|
||||
//! `l2norm` helper.
|
||||
//! - [`rope`] — text-side rotary embedding (mrope simplified, GLM
|
||||
//! rotate-half).
|
||||
//! - [`mlp`] — SwiGLU MLP (gate/up/down, no bias).
|
||||
//! - [`full_attn`] — `Qwen3_5Attention` with the output-gate
|
||||
//! widening on `q_proj`.
|
||||
//! - [`linear_attn`] — `GatedDeltaNet` recurrent delta-rule block
|
||||
//! (causal depthwise Conv1d → silu → split → L2norm → per-token
|
||||
//! delta rule → RMSNormGated → out_proj).
|
||||
//! - [`decoder`] — `Qwen3_5DecoderLayer` dispatching to one of the
|
||||
//! two attention flavours per layer index.
|
||||
//!
|
||||
//! ## Open work
|
||||
//!
|
||||
//! - **TP variant.** `harness/tp/tp_qwen3_5.rs` is the next step.
|
||||
//! Sharding strategy diverges by layer type:
|
||||
//! - Full-attention layers: column-parallel q/k/v (including the
|
||||
//! gate half of `q_proj`) + row-parallel `o_proj`, mirroring
|
||||
//! `tp_qwen3.rs`.
|
||||
//! - Linear-attention layers: the recurrent state is per-V-head, so
|
||||
//! V-head-dimension sharding works cleanly — split `num_v_heads`
|
||||
//! across ranks (`num_v_heads / world_size` per rank), shard
|
||||
//! `in_proj_qkv` / `in_proj_z` / `in_proj_b` / `in_proj_a` along
|
||||
//! the V-head dim, and row-parallel `out_proj`. The `A_log` /
|
||||
//! `dt_bias` per-head params shard with the heads.
|
||||
//!
|
||||
//! - **Chunked delta-rule prefill.** `linear_attn.rs` runs the
|
||||
//! per-token recurrent path for prefill too — correct but O(L).
|
||||
//! Porting `torch_chunk_gated_delta_rule` (chunk_size=64) speeds
|
||||
//! prefill substantially with no surface change.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::Embedding;
|
||||
use candle_nn::Linear;
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod decoder;
|
||||
pub mod full_attn;
|
||||
pub mod linear_attn;
|
||||
pub mod mlp;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod snapshot;
|
||||
pub mod vision;
|
||||
|
||||
use decoder::Qwen3_5DecoderLayer;
|
||||
use rmsnorm::Qwen3_5RmsNorm;
|
||||
use rope::RotaryEmbedding;
|
||||
|
||||
/// The `model_type` value found in this architecture's `config.json`.
/// Exposed as a constant so the dispatch in `candle.rs::load_arch_dense`
/// can pattern-match on it without repeating the string literal.
pub const MODEL_TYPE: &str = "qwen3_5";
|
||||
|
||||
/// Top-level shape of Qwen3-Next's `config.json`. The real
/// hyperparameters live in `text_config`; the rest is multimodal /
/// tokeniser glue we don't need for the language-model forward.
///
/// `model_type` and `text_config` are required when deserialising;
/// the vision-related fields fall back to `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Always `"qwen3_5"` for this architecture. Kept on the struct
    /// so the (eventual) dispatch / logging code can show it without
    /// re-parsing the JSON.
    pub model_type: String,
    /// The text-side hyperparameters. Everything we actually need.
    pub text_config: TextConfig,
    /// Vision tower hyperparameters. Present on multimodal
    /// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only
    /// variants. When present, `Qwen3_5ForCausalLM::new` loads the
    /// vision tower alongside the language model so vision-bearing
    /// requests can splice image embeddings at `<|image_pad|>` token
    /// positions.
    #[serde(default)]
    pub vision_config: Option<vision::VisionConfig>,
    /// Token id the chat template emits per image patch group.
    /// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for
    /// Qwen3.6). The runtime locates these in the prompt and splices
    /// in `VisionTower::forward` output. `None` for text-only models.
    #[serde(default)]
    pub image_token_id: Option<u32>,
}
|
||||
|
||||
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
/// but with the extras Qwen3-Next adds (`attn_output_gate`,
/// `layer_types`, `full_attention_interval`, larger `head_dim`).
///
/// Fields without a `#[serde(default…)]` attribute are required in the
/// JSON; the rest fall back to their type's default (or the named
/// default function) when missing.
#[derive(Debug, Clone, Deserialize)]
pub struct TextConfig {
    /// Vocabulary size — presumably the row count of the token embedding
    /// table; confirm against the model loader.
    pub vocab_size: usize,
    /// Width of the residual stream: input width of the attention and MLP
    /// projections and the width of the per-layer norms.
    pub hidden_size: usize,
    /// Inner width of the MLP (`gate_proj` / `up_proj` outputs).
    pub intermediate_size: usize,
    /// Number of decoder layers; `layer_types` is expected to have one
    /// entry per layer.
    pub num_hidden_layers: usize,
    /// Query head count of the full-attention layers.
    pub num_attention_heads: usize,
    /// Key/value head count of the full-attention layers; must be a
    /// non-zero divisor of `num_attention_heads`.
    pub num_key_value_heads: usize,
    /// Per-head dimension of the full-attention layers.
    pub head_dim: usize,
    /// Maximum position count declared by the checkpoint — presumably
    /// sizes the rotary tables; TODO confirm in `rope.rs`.
    pub max_position_embeddings: usize,
    /// Nested RoPE settings. Qwen3-Next puts `rope_theta` and
    /// `partial_rotary_factor` inside this block rather than at the
    /// top level — important because the partial rotary means only
    /// `head_dim * partial_rotary_factor` dims get RoPE applied (the
    /// rest pass through unchanged).
    pub rope_parameters: RopeParameters,
    /// Epsilon passed to every RMS norm that is built from this config.
    pub rms_norm_eps: f64,
    /// Checkpoint's `tie_word_embeddings` flag; `false` when absent.
    /// Presumably controls whether the output head reuses the embedding
    /// weights — confirm against the model loader.
    #[serde(default)]
    pub tie_word_embeddings: bool,

    /// New in Qwen3-Next: a sigmoid gate multiplied into the attention
    /// output before the o_proj. The Python reference applies it
    /// pointwise after softmax+matmul.
    ///
    /// NOTE(review): deserialises to `false` when absent, while the
    /// full-attention loader builds the gated (2x-wide) `q_proj`
    /// without consulting this flag — verify that is intended.
    #[serde(default)]
    pub attn_output_gate: bool,

    /// One entry per decoder layer; values are `"full_attention"` or
    /// `"linear_attention"`. Length must equal `num_hidden_layers`.
    /// `full_attention_interval` is a derived hint (every 4th layer
    /// by default) — `layer_types` is authoritative.
    #[serde(default)]
    pub layer_types: Vec<String>,

    /// Hint for the layer-type pattern. Deserialises to `None` when the
    /// key is absent (the "every 4th layer" default mentioned for
    /// `layer_types` is the upstream convention, not a value filled in
    /// here). Kept for logging / validation; the forward dispatches on
    /// `layer_types`.
    #[serde(default)]
    pub full_attention_interval: Option<usize>,

    /// Hidden activation (`"silu"` for Qwen3-Next). Used by the MLP
    /// and the linear-attention conv1d. Defaults to `"silu"` when absent.
    #[serde(default = "default_hidden_act")]
    pub hidden_act: String,

    // --- Gated DeltaNet (linear-attention) hyperparams -----------------
    // All of these deserialise to 0 when the key is absent.
    /// Per-layer linear-attention V-head count (Qwen3.6-27B: 48).
    /// More V-heads than K-heads is fine — query/key get
    /// `repeat_interleave`'d to match before the delta rule.
    #[serde(default)]
    pub linear_num_value_heads: usize,
    /// Per-layer linear-attention K-head count (Qwen3.6-27B: 16).
    #[serde(default)]
    pub linear_num_key_heads: usize,
    /// Per-head key dimension for the linear-attention path
    /// (Qwen3.6-27B: 128). Separate from `head_dim` which the
    /// full-attention layers use.
    #[serde(default)]
    pub linear_key_head_dim: usize,
    /// Per-head value dimension for the linear-attention path
    /// (Qwen3.6-27B: 128).
    #[serde(default)]
    pub linear_value_head_dim: usize,
    /// Causal Conv1d kernel size used before the delta rule
    /// (Qwen3.6-27B: 4).
    #[serde(default)]
    pub linear_conv_kernel_dim: usize,
}
|
||||
|
||||
/// Serde fallback for a missing `hidden_act` key: the `"silu"` activation.
fn default_hidden_act() -> String {
    String::from("silu")
}
|
||||
|
||||
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
///
/// For text-only inference the three MRoPE position grids carry
/// identical ids, so the interleave is a no-op and plain RoPE applies.
/// For vision inputs `mrope_section` + `mrope_interleaved` drive the
/// per-axis (text/height/width) rotary used by image tokens — see
/// `rope.rs`.
///
/// Every field is optional in the JSON; the fallback used when a key
/// is missing is noted per field.
#[derive(Debug, Clone, Deserialize)]
pub struct RopeParameters {
    /// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
    /// Falls back to 10_000.0 when absent.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,
    /// Fraction of `head_dim` that gets the rotation applied. The
    /// remaining `head_dim * (1 - partial_rotary_factor)` dims pass
    /// through unchanged. Qwen3.6 / Qwen3.5: 0.25. Falls back to 1.0
    /// (rotate every dim) when absent.
    #[serde(default = "default_partial_rotary_factor")]
    pub partial_rotary_factor: f32,
    /// `"default"` for the standard inv_freq RoPE; other values (e.g.
    /// `"linear"`, `"dynamic"`) are upstream-supported but not yet
    /// implemented here. `None` when absent.
    #[serde(default)]
    pub rope_type: Option<String>,
    /// MRoPE per-axis section sizes `[text, height, width]` — e.g.
    /// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim.
    /// Empty for models that don't declare MRoPE (→ plain RoPE).
    #[serde(default)]
    pub mrope_section: Vec<usize>,
    /// Whether the three MRoPE axes are interleaved per-frequency
    /// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated
    /// (Qwen2-VL style, `false`). `false` when absent.
    #[serde(default)]
    pub mrope_interleaved: bool,
}
|
||||
|
||||
/// Serde fallback for `RopeParameters::rope_theta` when the config
/// omits it: the classic RoPE base of ten thousand.
fn default_rope_theta() -> f64 {
    1.0e4_f64
}
|
||||
|
||||
/// Serde fallback for `RopeParameters::partial_rotary_factor` when the
/// config omits it: `1.0`, i.e. the whole head dimension is rotated.
fn default_partial_rotary_factor() -> f32 {
    1.0_f32
}
|
||||
|
||||
/// Splice rows from `img` into `h` at `positions`. Stage B helper.
|
||||
///
|
||||
/// `h`: `(1, L, hidden)` — the LM's input embedding tensor after
|
||||
/// `embed_tokens.forward`.
|
||||
/// `img`: `(N_img, hidden)` — image embeddings, one row per
|
||||
/// `<|image_pad|>` token in the prompt. Must already be in `h.dtype()`.
|
||||
/// `positions`: indices into the `L` axis where image rows go;
|
||||
/// `positions.len() == N_img`.
|
||||
///
|
||||
/// Approach: group `positions` into contiguous runs (because the chat
|
||||
/// template emits `<|vision_start|><|image_pad|>×N<|vision_end|>` —
|
||||
/// the pad tokens for each image land in one contiguous span), then
|
||||
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
|
||||
/// or two runs per image; `slice_assign` does one tensor copy per
|
||||
/// run, which is cheap relative to the decoder forward pass.
|
||||
pub(crate) fn splice_runs(
|
||||
h: &Tensor,
|
||||
img: &Tensor,
|
||||
positions: &[u32],
|
||||
) -> candle_core::Result<Tensor> {
|
||||
debug_assert!(
|
||||
!positions.is_empty(),
|
||||
"splice_runs precondition: non-empty positions"
|
||||
);
|
||||
let hidden = h.dim(2)?;
|
||||
let mut out = h.clone();
|
||||
let mut img_offset = 0_usize;
|
||||
let mut run_start = positions[0] as usize;
|
||||
let mut run_end_exclusive = run_start + 1;
|
||||
for &p in &positions[1..] {
|
||||
let p = p as usize;
|
||||
if p == run_end_exclusive {
|
||||
run_end_exclusive = p + 1;
|
||||
} else {
|
||||
apply_run(
|
||||
&mut out,
|
||||
img,
|
||||
&mut img_offset,
|
||||
run_start,
|
||||
run_end_exclusive,
|
||||
hidden,
|
||||
)?;
|
||||
run_start = p;
|
||||
run_end_exclusive = p + 1;
|
||||
}
|
||||
}
|
||||
apply_run(
|
||||
&mut out,
|
||||
img,
|
||||
&mut img_offset,
|
||||
run_start,
|
||||
run_end_exclusive,
|
||||
hidden,
|
||||
)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn apply_run(
|
||||
out: &mut Tensor,
|
||||
img: &Tensor,
|
||||
img_offset: &mut usize,
|
||||
run_start: usize,
|
||||
run_end_exclusive: usize,
|
||||
hidden: usize,
|
||||
) -> candle_core::Result<()> {
|
||||
let run_len = run_end_exclusive - run_start;
|
||||
let slice = img
|
||||
.narrow(0, *img_offset, run_len)?
|
||||
.reshape((1, run_len, hidden))?;
|
||||
*out = out.slice_assign(&[0..1, run_start..run_end_exclusive, 0..hidden], &slice)?;
|
||||
*img_offset += run_len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||
/// loaded handle.
|
||||
pub struct Qwen3_5Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<Qwen3_5DecoderLayer>,
|
||||
norm: Qwen3_5RmsNorm,
|
||||
/// Shared with every full-attention layer; the model uses it to
|
||||
/// build the per-forward cos/sin (interleaved M-RoPE for image
|
||||
/// tokens, plain for text) once, which the layers then apply.
|
||||
rotary: Arc<RotaryEmbedding>,
|
||||
/// `offset + rope_delta` is the text-axis position during decode.
|
||||
/// 0 for text-only; set from `get_rope_index` during a vision
|
||||
/// prefill (image tokens compress the position space, so text after
|
||||
/// the image resumes from a smaller counter than the sequence
|
||||
/// index). Reset in `clear_kv_cache`.
|
||||
rope_delta: i64,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Qwen3_5Model {
|
||||
pub fn load(cfg: &TextConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
|
||||
// Qwen3-Next is a multimodal architecture whose text core lives
|
||||
// under `model.language_model.*` — sibling to `model.visual.*`
|
||||
// (the vision tower) and to top-level `lm_head` / `mtp.*`.
|
||||
// Every text-side tensor in the safetensors files is under
|
||||
// this prefix; we ignore the vision and MTP weights for
|
||||
// language-model inference.
|
||||
let text_vb = vb.pp("model.language_model");
|
||||
|
||||
let embed_vb = text_vb.pp("embed_tokens");
|
||||
let embed_weight = embed_vb
|
||||
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||
.with_context(|| format!("load '{}/weight'", embed_vb.prefix()))?;
|
||||
let embed_tokens = Embedding::new(embed_weight, cfg.hidden_size);
|
||||
|
||||
let rotary = Arc::new(RotaryEmbedding::new(dtype, cfg, &device)?);
|
||||
|
||||
if cfg.layer_types.len() != cfg.num_hidden_layers {
|
||||
anyhow::bail!(
|
||||
"config.text_config.layer_types must have num_hidden_layers ({}) entries; \
|
||||
got {}",
|
||||
cfg.num_hidden_layers,
|
||||
cfg.layer_types.len()
|
||||
);
|
||||
}
|
||||
|
||||
let vb_l = text_vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
for i in 0..cfg.num_hidden_layers {
|
||||
layers.push(Qwen3_5DecoderLayer::load(
|
||||
cfg,
|
||||
rotary.clone(),
|
||||
i,
|
||||
&vb_l.pp(i),
|
||||
)?);
|
||||
}
|
||||
|
||||
let norm = Qwen3_5RmsNorm::load(&text_vb.pp("norm"), cfg.hidden_size, cfg.rms_norm_eps)?;
|
||||
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
rotary,
|
||||
rope_delta: 0,
|
||||
device,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embed_weight(&self) -> &Tensor {
|
||||
self.embed_tokens.embeddings()
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for l in &mut self.layers {
|
||||
l.clear_kv_cache();
|
||||
}
|
||||
// New request → no image-compressed position offset until the
|
||||
// next vision prefill sets one.
|
||||
self.rope_delta = 0;
|
||||
}
|
||||
|
||||
/// Capture every layer's cache state plus the rope position
|
||||
/// counter as one consistent prefix snapshot (#11). Only valid at
|
||||
/// a token boundary — i.e. between forward calls, which is the
|
||||
/// only time the caller can reach this anyway.
|
||||
pub fn snapshot_kv_cache(&self) -> candle_core::Result<snapshot::KvCacheSnapshot> {
|
||||
let layers = self
|
||||
.layers
|
||||
.iter()
|
||||
.map(|l| l.snapshot_kv())
|
||||
.collect::<candle_core::Result<Vec<_>>>()?;
|
||||
Ok(snapshot::KvCacheSnapshot {
|
||||
layers,
|
||||
rope_delta: self.rope_delta,
|
||||
})
|
||||
}
|
||||
|
||||
/// Replace the live cache state with a previously captured
|
||||
/// snapshot. The snapshot stays valid for further restores.
|
||||
pub fn restore_kv_cache(
|
||||
&mut self,
|
||||
snap: &snapshot::KvCacheSnapshot,
|
||||
) -> candle_core::Result<()> {
|
||||
if snap.layers.len() != self.layers.len() {
|
||||
candle_core::bail!(
|
||||
"restore_kv_cache: snapshot has {} layers, model has {}",
|
||||
snap.layers.len(),
|
||||
self.layers.len()
|
||||
);
|
||||
}
|
||||
for (layer, layer_snap) in self.layers.iter_mut().zip(snap.layers.iter()) {
|
||||
layer.restore_kv(layer_snap)?;
|
||||
}
|
||||
self.rope_delta = snap.rope_delta;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let minf = f32::NEG_INFINITY;
|
||||
let mask: Vec<_> = (0..tgt)
|
||||
.flat_map(|i| (0..(tgt + offset)).map(move |j| if j <= i + offset { 0. } else { minf }))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
self.forward_inner(input, offset, None, None, &[], None)
|
||||
}
|
||||
|
||||
/// Forward for a vision-prefill chunk: optional image-embedding
|
||||
/// splice plus explicit interleaved-M-RoPE `position_ids` (the
|
||||
/// chunk's slice of the full prompt's 3D positions). Mirrors the TP
|
||||
/// `TpQwen3_5Model::forward_with_positions` — used by
|
||||
/// `Qwen3_5ForCausalLM::prefill_with_images_chunked`, which computes
|
||||
/// the positions once over the whole prompt and slices them per
|
||||
/// chunk so the position counters stay consistent across chunk
|
||||
/// boundaries (an image compresses the position space, so per-chunk
|
||||
/// offset arithmetic would be wrong).
|
||||
pub fn forward_with_positions(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
position_ids: &Tensor,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
self.forward_inner(
|
||||
input,
|
||||
offset,
|
||||
image_embeds,
|
||||
image_token_id,
|
||||
&[],
|
||||
Some(position_ids),
|
||||
)
|
||||
}
|
||||
|
||||
/// Forward with image-embedding splice. Stage B of the vision plan.
|
||||
///
|
||||
/// `input_ids`: `(1, L)` token ids — same shape the text-only
|
||||
/// `forward` accepts (single-batch; multi-batch vision is not in
|
||||
/// scope today).
|
||||
/// `image_embeds`: `(N_image_tokens, hidden_size)` — concatenation
|
||||
/// of every image's post-merger embedding (`VisionTower::forward`
|
||||
/// output), in the same order images appear in the input. The
|
||||
/// caller has already done the per-image patch-count expansion of
|
||||
/// `<|image_pad|>` tokens in `input_ids`, so `N_image_tokens`
|
||||
/// equals the number of `image_token_id` positions in `input_ids`.
|
||||
/// `image_token_id`: the sentinel token (e.g. 248056 for Qwen3.6).
|
||||
///
|
||||
/// The splice replaces the LM's text-side embedding at each
|
||||
/// `image_token_id` position with the corresponding row from
|
||||
/// `image_embeds`. After the splice the decoder runs the interleaved
|
||||
/// M-RoPE path: `grids` carries each image's post-merge LM grid
|
||||
/// `(lm_gh, lm_gw)` so `get_rope_index` assigns image tokens their 2D
|
||||
/// coordinates (dynamic resolution, #14).
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
grids: &[(usize, usize)],
|
||||
) -> candle_core::Result<Tensor> {
|
||||
self.forward_inner(
|
||||
input_ids,
|
||||
offset,
|
||||
Some(image_embeds),
|
||||
Some(image_token_id),
|
||||
grids,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
/// Shared forward. Splices image embeddings at `image_token_id`
|
||||
/// positions when present, then builds the rotary cos/sin, in
|
||||
/// precedence order: explicit `position_ids` (interleaved M-RoPE,
|
||||
/// the chunked-vision path that slices a once-computed position
|
||||
/// tensor) > internal M-RoPE from `grids` (single-shot vision) >
|
||||
/// plain positions at `offset + rope_delta` (text / decode).
|
||||
fn forward_inner(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
grids: &[(usize, usize)],
|
||||
position_ids: Option<&Tensor>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (b, l) = input.dims2()?;
|
||||
let mut h = self.embed_tokens.forward(input)?;
|
||||
|
||||
// Splice image embeddings at `image_token_id` positions, when
|
||||
// this forward carries any. Independent of how cos/sin is built.
|
||||
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||
for (idx, id) in ids.iter().enumerate() {
|
||||
if *id == tok_id {
|
||||
positions.push(idx as u32);
|
||||
}
|
||||
}
|
||||
let n_img_tokens = img.dim(0)?;
|
||||
if positions.len() != n_img_tokens {
|
||||
candle_core::bail!(
|
||||
"forward_with_vision: chunk has {} image-token positions but \
|
||||
image_embeds carries {} tokens — per-image patch-count expansion \
|
||||
/ chunk slicing mismatch",
|
||||
positions.len(),
|
||||
n_img_tokens,
|
||||
);
|
||||
}
|
||||
if !positions.is_empty() {
|
||||
// Cast image_embeds to the LM's dtype, then splice the
|
||||
// contiguous `<|image_pad|>` runs in place.
|
||||
let img = img.to_dtype(self.dtype)?;
|
||||
h = splice_runs(&h, &img, &positions)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Build interleaved M-RoPE cos/sin so image tokens carry their
|
||||
// 2D (lm_gh × lm_gw) grid coordinates. Text / decode take the
|
||||
// plain-RoPE fast path — bit-for-bit the pre-M-RoPE behaviour
|
||||
// when `rope_delta == 0`.
|
||||
let (cos, sin) = if let Some(pos) = position_ids {
|
||||
// Pre-computed positions sliced for this chunk — the splice
|
||||
// above already advanced `rope_delta`'s effect into `pos`.
|
||||
self.rotary.mrope_cos_sin(pos)?
|
||||
} else if let Some(tok_id) = image_token_id {
|
||||
// Single-shot vision: compute the whole prompt's M-RoPE here
|
||||
// and stash `rope_delta` for the decode that follows.
|
||||
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id, grids)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||
self.rope_delta = delta;
|
||||
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
|
||||
self.rotary.mrope_cos_sin(&pos)?
|
||||
} else {
|
||||
let base = (offset as i64 + self.rope_delta).max(0) as usize;
|
||||
self.rotary.plain_cos_sin(base, l)?
|
||||
};
|
||||
|
||||
// Causal mask only needed for L > 1 prefill; full-attention
|
||||
// layers consume it via broadcast_add. Linear-attention layers
|
||||
// ignore the mask.
|
||||
let causal = if l == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.causal_mask(b, l, offset)?)
|
||||
};
|
||||
for layer in &mut self.layers {
|
||||
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
|
||||
}
|
||||
self.norm.forward(&h)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Qwen3_5ForCausalLM {
|
||||
base: Qwen3_5Model,
|
||||
lm_head: Linear,
|
||||
/// Vision tower (Stage A4). `None` for text-only checkpoints or
|
||||
/// when the operator has opted out. When present, the harness's
|
||||
/// `Job::EncodeImage` dispatch path runs `vision.forward(image)`
|
||||
/// and the LM forward (Stage B) splices the result at
|
||||
/// `image_token_id` positions in the input embedding stream.
|
||||
vision: Option<vision::VisionTower>,
|
||||
/// Mirrors `Config::image_token_id`. Cached here so the runtime
|
||||
/// doesn't have to round-trip through the parsed config struct.
|
||||
image_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl Qwen3_5ForCausalLM {
|
||||
pub fn new(config: Config, vb: ShardedVarBuilder) -> Result<Self> {
|
||||
let cfg = &config.text_config;
|
||||
let base = Qwen3_5Model::load(cfg, &vb)?;
|
||||
let lm_head = if cfg.tie_word_embeddings {
|
||||
Linear::new(base.embed_weight().clone(), None)
|
||||
} else {
|
||||
let weight = vb
|
||||
.pp("lm_head")
|
||||
.get((cfg.vocab_size, cfg.hidden_size), "weight")
|
||||
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||
Linear::new(weight, None)
|
||||
};
|
||||
// Stage A4: load the vision tower when the config carries a
|
||||
// `vision_config` block and the safetensors actually carry
|
||||
// `model.visual.*` weights. The `Option<VisionConfig>` on the
|
||||
// config makes this a single-source-of-truth decision —
|
||||
// text-only checkpoints just leave `vision_config` unset and
|
||||
// get `None` here without any extra plumbing.
|
||||
let vision = if let Some(vcfg) = config.vision_config.clone() {
|
||||
tracing::info!(
|
||||
depth = vcfg.depth,
|
||||
hidden_size = vcfg.hidden_size,
|
||||
"loading qwen3_5 vision tower"
|
||||
);
|
||||
Some(
|
||||
vision::VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||
.context("load qwen3_5 vision tower (model.visual.*)")?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
base,
|
||||
lm_head,
|
||||
vision,
|
||||
image_token_id: config.image_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// True when this checkpoint loaded a vision tower. Used by the
|
||||
/// HTTP layer to advertise vision capability in `/v1/models` and
|
||||
/// to reject image-bearing requests against text-only loads with
|
||||
/// a clean 400.
|
||||
pub fn has_vision(&self) -> bool {
|
||||
self.vision.is_some()
|
||||
}
|
||||
|
||||
/// Vision tower handle, if loaded. The device-worker
|
||||
/// `EncodeImage` job dispatches to `vision.forward(image)`.
|
||||
pub fn vision(&self) -> Option<&vision::VisionTower> {
|
||||
self.vision.as_ref()
|
||||
}
|
||||
|
||||
/// `<|image_pad|>` token id from `config.json`, when known.
|
||||
/// The Stage B prompt-builder uses this to count expansion targets
|
||||
/// and the LM forward uses it to locate splice positions.
|
||||
pub fn image_token_id(&self) -> Option<u32> {
|
||||
self.image_token_id
|
||||
}
|
||||
|
||||
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||
/// the last position, shape `(B, 1, vocab_size)` — same contract
|
||||
/// as `qwen3::ModelForCausalLM::forward` so the harness's
|
||||
/// `squeeze_to_vocab` helper handles both uniformly.
|
||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self.base.forward(input, offset)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Stage B: forward with image-embedding splice. Mirrors `forward`
|
||||
/// but routes through `Qwen3_5Model::forward_with_vision` so the
|
||||
/// LM's input embeddings get the image patches spliced in at
|
||||
/// `image_token_id` positions before the decoder stack runs.
|
||||
pub fn forward_with_vision(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
image_embeds: &Tensor,
|
||||
image_token_id: u32,
|
||||
grids: &[(usize, usize)],
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden =
|
||||
self.base
|
||||
.forward_with_vision(input, offset, image_embeds, image_token_id, grids)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Forward for a vision-prefill chunk: explicit M-RoPE positions +
|
||||
/// optional image splice. Mirrors `forward_with_vision` but routes
|
||||
/// through `Qwen3_5Model::forward_with_positions`. Used by
|
||||
/// [`Self::prefill_with_images_chunked`].
|
||||
pub fn forward_with_positions(
|
||||
&mut self,
|
||||
input: &Tensor,
|
||||
offset: usize,
|
||||
position_ids: &Tensor,
|
||||
image_embeds: Option<&Tensor>,
|
||||
image_token_id: Option<u32>,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let (_, l) = input.dims2()?;
|
||||
let hidden = self.base.forward_with_positions(
|
||||
input,
|
||||
offset,
|
||||
position_ids,
|
||||
image_embeds,
|
||||
image_token_id,
|
||||
)?;
|
||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
/// Encode every preprocessed `(C, H, W)` image once through the
|
||||
/// vision tower and concatenate along the patch axis →
|
||||
/// `(sum_patches, hidden)`. Done once per prefill, not per chunk.
|
||||
fn encode_images_concat(&self, image_pixels: &[Tensor]) -> candle_core::Result<Tensor> {
|
||||
let tower = self.vision.as_ref().ok_or_else(|| {
|
||||
candle_core::Error::Msg(
|
||||
"encode_images_concat: loaded without a vision tower \
|
||||
(config.json::vision_config absent or weights missing)"
|
||||
.into(),
|
||||
)
|
||||
})?;
|
||||
let mut per_image = Vec::with_capacity(image_pixels.len());
|
||||
for (idx, img) in image_pixels.iter().enumerate() {
|
||||
let embed = tower
|
||||
.forward(img)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
|
||||
per_image.push(embed);
|
||||
}
|
||||
Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)
|
||||
}
|
||||
|
||||
/// Chunked image prefill for the single-GPU path (#18) — parity with
|
||||
/// `TpQwen3_5ForCausalLM::prefill_with_images_chunked`. Encodes the
|
||||
/// image(s) once, then walks the (pre-expanded) prompt in
|
||||
/// `chunk_size`-token windows — exactly like the text
|
||||
/// `chunked_prefill_*` paths — splicing the patch embeddings into
|
||||
/// whichever chunk(s) carry `<|image_pad|>` positions. Activation
|
||||
/// memory is bounded by the chunk, not the full prompt, so a long
|
||||
/// vision context no longer single-shot-OOMs.
|
||||
///
|
||||
/// The KV cache (and GDN recurrent state) accumulate across chunks
|
||||
/// via the growing offset — the same per-chunk associativity the
|
||||
/// text chunked prefill and prefix cache (#11/#23) rely on. Only the
|
||||
/// final chunk's last-position logits are returned; intermediate
|
||||
/// chunks just populate the cache. The caller is responsible for
|
||||
/// clearing the cache first.
|
||||
///
|
||||
/// `base_offset` is the KV position the prefill starts at (0 for a
|
||||
/// fresh request). `image_pixels` are device-resident `(C, H, W)`
|
||||
/// tensors; grids and the interleaved-M-RoPE position ids are
|
||||
/// recomputed here so an image's position compression is consistent
|
||||
/// across chunk boundaries.
|
||||
pub fn prefill_with_images_chunked(
|
||||
&mut self,
|
||||
tokens: &[u32],
|
||||
base_offset: usize,
|
||||
image_pixels: &[Tensor],
|
||||
image_token_id: u32,
|
||||
chunk_size: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
if image_pixels.is_empty() {
|
||||
candle_core::bail!("prefill_with_images_chunked: called with zero images");
|
||||
}
|
||||
if tokens.is_empty() {
|
||||
candle_core::bail!("prefill_with_images_chunked: empty prompt");
|
||||
}
|
||||
let chunk_size = chunk_size.max(1);
|
||||
let device = self.base.device.clone();
|
||||
|
||||
let image_embeds = self.encode_images_concat(image_pixels)?;
|
||||
|
||||
// Each image's LM grid (lm_gh, lm_gw) = (h/factor, w/factor),
|
||||
// factor = patch×merge — recomputed from the pixel tensors (#14
|
||||
// dynamic resolution).
|
||||
let factor = self
|
||||
.vision
|
||||
.as_ref()
|
||||
.map(|v| {
|
||||
let c = v.config();
|
||||
c.patch_size * c.spatial_merge_size
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
candle_core::Error::Msg(
|
||||
"prefill_with_images_chunked: loaded without a vision tower".into(),
|
||||
)
|
||||
})?;
|
||||
let grids: Vec<(usize, usize)> = image_pixels
|
||||
.iter()
|
||||
.map(|t| {
|
||||
let (_, h, w) = t.dims3()?;
|
||||
Ok::<(usize, usize), candle_core::Error>((h / factor, w / factor))
|
||||
})
|
||||
.collect::<candle_core::Result<Vec<_>>>()?;
|
||||
|
||||
// Interleaved-M-RoPE 3D positions for the whole prompt, computed
|
||||
// once and sliced per chunk so image tokens get their grid
|
||||
// coordinates and text after an image resumes from the
|
||||
// compressed counter. `rope_delta` is stashed on the base model
|
||||
// for the decode that follows this prefill.
|
||||
let (text, height, width, delta) = rope::get_rope_index(tokens, image_token_id, &grids)
|
||||
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||
self.base.rope_delta = delta;
|
||||
let full_pos = rope::mrope_position_tensor(&text, &height, &width, &device)?;
|
||||
|
||||
let mut last_logits: Option<Tensor> = None;
|
||||
// Rows of `image_embeds` already spliced by earlier chunks. The
|
||||
// `<|image_pad|>` run is contiguous, so chunks consume embedding
|
||||
// rows in order.
|
||||
let mut img_off = 0usize;
|
||||
let mut start = 0usize;
|
||||
while start < tokens.len() {
|
||||
let end = (start + chunk_size).min(tokens.len());
|
||||
let chunk = &tokens[start..end];
|
||||
let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
|
||||
let pos_slice = full_pos.narrow(1, start, end - start)?;
|
||||
let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
|
||||
let logits = if n_here == 0 {
|
||||
self.forward_with_positions(&input, base_offset + start, &pos_slice, None, None)?
|
||||
} else {
|
||||
// Splice the next `n_here` patch rows at this chunk's
|
||||
// local image-pad positions.
|
||||
let rows = image_embeds.narrow(0, img_off, n_here)?;
|
||||
img_off += n_here;
|
||||
self.forward_with_positions(
|
||||
&input,
|
||||
base_offset + start,
|
||||
&pos_slice,
|
||||
Some(&rows),
|
||||
Some(image_token_id),
|
||||
)?
|
||||
};
|
||||
last_logits = Some(logits);
|
||||
start = end;
|
||||
}
|
||||
last_logits
|
||||
.ok_or_else(|| candle_core::Error::Msg("prefill_with_images_chunked: no chunks".into()))
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.base.clear_kv_cache();
|
||||
}
|
||||
|
||||
/// See [`Qwen3_5Model::snapshot_kv_cache`].
|
||||
pub fn snapshot_kv_cache(&self) -> candle_core::Result<snapshot::KvCacheSnapshot> {
|
||||
self.base.snapshot_kv_cache()
|
||||
}
|
||||
|
||||
/// See [`Qwen3_5Model::restore_kv_cache`].
|
||||
pub fn restore_kv_cache(
|
||||
&mut self,
|
||||
snap: &snapshot::KvCacheSnapshot,
|
||||
) -> candle_core::Result<()> {
|
||||
self.base.restore_kv_cache(snap)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The real upstream config shape must deserialise.
    ///
    /// Sample taken from `Qwen/Qwen3.6-27B/config.json`, trimmed to the
    /// fields the architecture cares about. Note that `rope_theta` and
    /// `partial_rotary_factor` are nested under `rope_parameters` —
    /// Qwen3-Next does NOT have a top-level `rope_theta`.
    #[test]
    fn config_deserialises_the_real_qwen3_6_shape() {
        let raw = r#"{
            "architectures": ["Qwen3_5ForConditionalGeneration"],
            "model_type": "qwen3_5",
            "image_token_id": 248056,
            "language_model_only": false,
            "text_config": {
                "vocab_size": 248064,
                "hidden_size": 5120,
                "intermediate_size": 17408,
                "num_hidden_layers": 64,
                "num_attention_heads": 64,
                "num_key_value_heads": 8,
                "head_dim": 256,
                "max_position_embeddings": 32768,
                "rope_parameters": {
                    "mrope_interleaved": true,
                    "mrope_section": [11, 11, 10],
                    "partial_rotary_factor": 0.25,
                    "rope_theta": 10000000,
                    "rope_type": "default"
                },
                "rms_norm_eps": 1e-6,
                "tie_word_embeddings": false,
                "attn_output_gate": true,
                "full_attention_interval": 4,
                "layer_types": [
                    "linear_attention", "linear_attention",
                    "linear_attention", "full_attention"
                ]
            }
        }"#;
        let cfg: Config = serde_json::from_str(raw).expect("parse Qwen3.6 config");
        assert_eq!(cfg.model_type, "qwen3_5");
        assert_eq!(cfg.text_config.hidden_size, 5120);
        assert_eq!(cfg.text_config.head_dim, 256);
        assert!(cfg.text_config.attn_output_gate);
        assert_eq!(cfg.text_config.full_attention_interval, Some(4));
        assert_eq!(cfg.text_config.layer_types.len(), 4);
        assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
        assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
    }

    /// `splice_runs` replaces (1, L, H) embedding rows at the given
    /// positions with rows from a (N_img, H) image-embedding tensor,
    /// in the order positions are supplied.
    #[test]
    fn splice_runs_replaces_at_contiguous_positions() {
        use candle_core::Device;

        let dev = Device::Cpu;
        // (1, L=5, H=2) text embeddings with distinct values per
        // position, so the assertion pins exactly which rows changed.
        let h_vals: Vec<f32> = vec![
            10., 11., // pos 0
            20., 21., // pos 1
            30., 31., // pos 2
            40., 41., // pos 3
            50., 51., // pos 4
        ];
        let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();

        // Two image embeddings to splice at positions 1 and 2 (a
        // contiguous run — single image emitting two patch tokens).
        let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
        let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();

        let out = splice_runs(&h, &img, &[1, 2]).unwrap();
        let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(flat, vec![10., 11., -1., -2., -3., -4., 40., 41., 50., 51.]);
    }

    /// Non-contiguous positions: two images at positions [1] and [3],
    /// each contributing one patch. `splice_runs` should iterate the
    /// runs and place the corresponding image rows.
    #[test]
    fn splice_runs_handles_non_contiguous_runs() {
        use candle_core::Device;

        let dev = Device::Cpu;
        let h_vals: Vec<f32> = vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.];
        let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();
        let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
        let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();

        let out = splice_runs(&h, &img, &[1, 3]).unwrap();
        let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(flat, vec![1., 1., -1., -2., 3., 3., -3., -4., 5., 5.]);
    }
}
|
||||
@@ -1,161 +0,0 @@
|
||||
//! Norm primitives for Qwen3-Next.
|
||||
//!
|
||||
//! Two reasons we can't reuse `candle_nn::RmsNorm` directly:
|
||||
//!
|
||||
//! 1. **`(1.0 + weight)` scaling.** Qwen3-Next's `Qwen3_5RMSNorm`
|
||||
//! initialises `weight` to zeros and applies `(1.0 + weight)` to
|
||||
//! the normalised vector. `candle_nn::RmsNorm` applies `weight`
|
||||
//! directly. The two are equivalent only when the operator has
|
||||
//! pre-shifted the weights — the upstream checkpoints have not. See
|
||||
//! `huggingface/transformers#29402` for the upstream PR that
|
||||
//! introduced the `(1 + w)` form to recover from the zero-init.
|
||||
//!
|
||||
//! 2. **Gated variant.** The linear-attention layer post-normalises
|
||||
//! its output by an RMSNorm *gated* with a per-element SiLU on
|
||||
//! a sibling `z` projection — fused for numerical reasons (the
|
||||
//! norm's float32 promotion has to happen before the SiLU
|
||||
//! multiply). Not a single existing candle op.
|
||||
//!
|
||||
//! Both ops accept inputs in any compute dtype; promotion to f32 for
|
||||
//! the variance calculation matches the Python reference.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{D, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
|
||||
/// L2-normalise along the last dim with a small epsilon. Matches the
|
||||
/// `l2norm` helper in `transformers/models/qwen3_5/modeling_qwen3_5.py`
|
||||
/// — `x * rsqrt(sum(x*x) + eps)`. The linear-attention path uses this
|
||||
/// on Q and K before the delta rule when
|
||||
/// `use_qk_l2norm_in_kernel=True` (which Qwen3-Next always sets).
|
||||
pub fn l2norm(x: &Tensor, eps: f32) -> candle_core::Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||
let sq = x_f32.sqr()?;
|
||||
let sum = sq.sum_keepdim(D::Minus1)?;
|
||||
let inv = (sum + eps as f64)?.sqrt()?.recip()?;
|
||||
x_f32.broadcast_mul(&inv)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
/// Qwen3-Next's RMSNorm. Stores the raw weight tensor; forward applies
|
||||
/// `(1.0 + weight) * x_normed`.
|
||||
pub struct Qwen3_5RmsNorm {
|
||||
weight: Tensor,
|
||||
eps: f32,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl Qwen3_5RmsNorm {
|
||||
/// Load `weight` from the ShardedVarBuilder. `vb` should already be
|
||||
/// `.pp(...)`-ed to the norm's tensor prefix.
|
||||
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get(size, "weight")
|
||||
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
eps: eps as f32,
|
||||
size,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Qwen3_5RmsNorm {
|
||||
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||
// Promote weight to f32 and shift by 1.0 *before* multiplying.
|
||||
// Doing the (1 + w) operation in fp16 lands at -inf for the
|
||||
// bottom-of-range weights at load time.
|
||||
let w_f32 = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||
let scale = (w_f32 + 1.0_f64)?;
|
||||
normed.broadcast_mul(&scale)?.to_dtype(dtype)
|
||||
}
|
||||
}
|
||||
|
||||
/// Gated RMSNorm used at the tail of `Qwen3_5GatedDeltaNet`. Equivalent
|
||||
/// to `x_normed * weight * silu(gate)` but with both the norm and the
|
||||
/// gate evaluated in float32 to avoid mid-pipeline underflow.
|
||||
///
|
||||
/// Note: unlike `Qwen3_5RmsNorm`, this variant matches the Python
|
||||
/// reference's `Qwen3_5RMSNormGated` which uses `weight` directly (not
|
||||
/// `1.0 + weight`).
|
||||
pub struct Qwen3_5RmsNormGated {
|
||||
weight: Tensor,
|
||||
eps: f32,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl Qwen3_5RmsNormGated {
|
||||
pub fn load(vb: &ShardedVarBuilder, size: usize, eps: f64) -> Result<Self> {
|
||||
let weight = vb
|
||||
.get(size, "weight")
|
||||
.with_context(|| format!("load '{}/weight'", vb.prefix()))?;
|
||||
Ok(Self {
|
||||
weight,
|
||||
eps: eps as f32,
|
||||
size,
|
||||
})
|
||||
}
|
||||
|
||||
/// Direct constructor — used by unit tests that build a layer
|
||||
/// without going through a VarBuilder.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
|
||||
let size = weight.dims()[0];
|
||||
Self {
|
||||
weight,
|
||||
eps: eps as f32,
|
||||
size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
/// `x` and `gate` share the same last-dim shape (`size`).
|
||||
pub fn forward(&self, x: &Tensor, gate: &Tensor) -> candle_core::Result<Tensor> {
|
||||
let dtype = x.dtype();
|
||||
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
|
||||
let var = x_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let normed = x_f32.broadcast_mul(&(var + self.eps as f64)?.sqrt()?.recip()?)?;
|
||||
let w = self.weight.to_dtype(candle_core::DType::F32)?;
|
||||
let out = normed.broadcast_mul(&w)?;
|
||||
// SiLU on the float32 gate, multiply back into the normed
|
||||
// tensor, then cast to the model dtype.
|
||||
let g = gate.to_dtype(candle_core::DType::F32)?;
|
||||
let silu_gate = candle_nn::ops::silu(&g)?;
|
||||
(out * silu_gate)?.to_dtype(dtype)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// A 3-4-5 triangle: the unit vector is [0.6, 0.8] (eps is tiny).
    #[test]
    fn l2norm_matches_hand_calc() {
        let input = Tensor::new(&[3.0_f32, 4.0_f32], &Device::Cpu).unwrap();
        let got: Vec<f32> = l2norm(&input, 1e-6).unwrap().to_vec1().unwrap();
        assert_eq!(got.len(), 2);
        for (actual, expected) in got.iter().zip([0.6_f32, 0.8_f32]) {
            assert!((actual - expected).abs() < 1e-4);
        }
    }

    /// The epsilon keeps an all-zero input from producing NaN/inf.
    #[test]
    fn l2norm_zero_vector_is_safe_via_epsilon() {
        let zeros = Tensor::new(&[0.0_f32, 0.0_f32], &Device::Cpu).unwrap();
        let got: Vec<f32> = l2norm(&zeros, 1e-6).unwrap().to_vec1().unwrap();
        assert!(got.iter().all(|value| value.is_finite()));
    }
}
|
||||
@@ -1,579 +0,0 @@
|
||||
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||
//!
|
||||
//! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the
|
||||
//! rotary half-dimension is split across three position axes —
|
||||
//! `[text, height, width]` per `mrope_section` (`[11,11,10]` for
|
||||
//! Qwen3.6) — interleaved per-frequency. For **text** every token's
|
||||
//! three axes carry the same position id, so the interleave is a no-op
|
||||
//! and this reduces exactly to plain RoPE. For **image** tokens the
|
||||
//! height/width axes carry the patch's 2D grid coordinates, which is
|
||||
//! how the model reads the 14×14 patch layout (without it, all patches
|
||||
//! share a height position and the image reads as vertical repetition).
|
||||
//!
|
||||
//! Two cos/sin builders feed a shared [`RotaryEmbedding::apply_cos_sin`]:
|
||||
//! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables
|
||||
//! at a scalar position — the text / decode fast path.
|
||||
//! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a
|
||||
//! `(3, seq)` position-id tensor, blending the three axes' frequencies
|
||||
//! at the interleave index sets — the vision-prefill path.
|
||||
//!
|
||||
//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`),
|
||||
//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`.
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{DType, Device, IndexOp, Tensor};
|
||||
|
||||
use super::TextConfig;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
/// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond
|
||||
/// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can
|
||||
/// build cos/sin from arbitrary per-axis position ids.
|
||||
inv_freq: Tensor,
|
||||
/// Per-axis column masks over the rotary half-dim, shape `(1, half)`,
|
||||
/// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a
|
||||
/// column belongs to exactly one axis. For a non-MRoPE config
|
||||
/// `mask_t` is all-ones and the others all-zero (→ plain RoPE).
|
||||
mask_t: Tensor,
|
||||
mask_h: Tensor,
|
||||
mask_w: Tensor,
|
||||
dtype: DType,
|
||||
/// Number of dims at the head's leading edge that the rotation
|
||||
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||
/// for `head_dim = 256` only 64 dims rotate.
|
||||
rotary_dim: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
/// Build the per-axis 0/1 column masks over the rotary half-dim from
/// `mrope_section`.
///
/// Returns `(temporal, height, width)`, each of length `half`. Temporal
/// is the complement of height ∪ width, so the three masks always
/// partition `0..half`.
///
/// * `section` must have exactly three entries `[text, height, width]`
///   to have any effect; any other length yields all-temporal masks
///   (plain RoPE).
/// * `interleaved = true` (Qwen3-VL layout): height owns columns
///   1, 4, 7, … and width owns 2, 5, 8, …, limited to `section[1]` and
///   `section[2]` columns respectively.
/// * `interleaved = false` (Qwen2-VL layout): contiguous blocks
///   `[text | height | width]` of the given lengths, clamped to `half`.
///
/// In both layouts a section list that covers fewer than `half` columns
/// leaves the uncovered columns on the temporal axis. The width block is
/// therefore bounded by `section[2]` rather than running to the end of
/// the half-dim.
fn mrope_masks(
    half: usize,
    section: &[usize],
    interleaved: bool,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let mut mh = vec![0f32; half];
    let mut mw = vec![0f32; half];
    if let &[text_len, height_len, width_len] = section {
        if interleaved {
            // Temporal keeps 0,3,6,…; each other axis takes at most its
            // `mrope_section` count of every-third columns.
            for m in mh.iter_mut().skip(1).step_by(3).take(height_len) {
                *m = 1.0;
            }
            for m in mw.iter_mut().skip(2).step_by(3).take(width_len) {
                *m = 1.0;
            }
        } else {
            // Saturating adds keep absurd section values from wrapping;
            // every bound is clamped to `half`, so the ranges are always
            // valid and ordered (h_start <= h_end <= w_end <= half).
            let h_start = text_len.min(half);
            let h_end = text_len.saturating_add(height_len).min(half);
            let w_end = h_end.saturating_add(width_len).min(half);
            mh[h_start..h_end].fill(1.0);
            mw[h_end..w_end].fill(1.0);
        }
    }
    let mt: Vec<f32> = mh
        .iter()
        .zip(&mw)
        .map(|(&h, &w)| if h == 0.0 && w == 0.0 { 1.0 } else { 0.0 })
        .collect();
    (mt, mh, mw)
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim;
|
||||
let rope = &cfg.rope_parameters;
|
||||
let rotary_dim = (head_dim as f32 * rope.partial_rotary_factor) as usize;
|
||||
if !rotary_dim.is_multiple_of(2) {
|
||||
anyhow::bail!(
|
||||
"rotary_dim = head_dim * partial_rotary_factor = {head_dim} * {} = {rotary_dim} \
|
||||
must be even (cos/sin are paired)",
|
||||
rope.partial_rotary_factor
|
||||
);
|
||||
}
|
||||
if rotary_dim == 0 {
|
||||
anyhow::bail!(
|
||||
"rotary_dim = 0 (partial_rotary_factor = {} too small)",
|
||||
rope.partial_rotary_factor
|
||||
);
|
||||
}
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<f32> = (0..rotary_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||
.collect();
|
||||
let half = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
|
||||
// MRoPE axis masks. `sum(mrope_section)` should equal `half`;
|
||||
// warn-tolerant: any shortfall just stays on the temporal axis.
|
||||
let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved);
|
||||
let mask_t = Tensor::from_vec(mt, (1, half), dev)?;
|
||||
let mask_h = Tensor::from_vec(mh, (1, half), dev)?;
|
||||
let mask_w = Tensor::from_vec(mw, (1, half), dev)?;
|
||||
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||
inv_freq,
|
||||
mask_t,
|
||||
mask_h,
|
||||
mask_w,
|
||||
dtype,
|
||||
rotary_dim,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
/// cos/sin for a contiguous run of `seq_len` positions starting at
|
||||
/// `pos`, by narrowing the precomputed tables. The text / decode
|
||||
/// path (all three MRoPE axes equal → plain RoPE). Shape
|
||||
/// `(seq_len, rotary_dim/2)`.
|
||||
pub fn plain_cos_sin(
|
||||
&self,
|
||||
pos: usize,
|
||||
seq_len: usize,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let cos = self.cos.narrow(0, pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, pos, seq_len)?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
/// cos/sin from explicit per-token 3D position ids, shape
|
||||
/// `(3, seq_len)` (axes: text, height, width). Builds each axis's
|
||||
/// frequencies and blends them at the interleave index sets, so
|
||||
/// every rotary frequency slot is driven by exactly one axis.
|
||||
/// Reduces exactly to [`Self::plain_cos_sin`] when the three axes are
|
||||
/// equal. Returns cos/sin of shape `(seq_len, rotary_dim/2)`.
|
||||
pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let pos = position_ids.to_dtype(DType::F32)?;
|
||||
let (axes, seq_len) = pos.dims2()?;
|
||||
debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes");
|
||||
// Per-axis freqs: pos[a] (seq,1) @ inv_freq (1,half) → (seq,half).
|
||||
let ft = pos.i(0)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||
let fh = pos.i(1)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||
let fw = pos.i(2)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||
// Blend: each column belongs to exactly one axis (masks partition
|
||||
// the half-dim), so this picks the right axis per frequency slot.
|
||||
let blended = ft
|
||||
.broadcast_mul(&self.mask_t)?
|
||||
.add(&fh.broadcast_mul(&self.mask_h)?)?
|
||||
.add(&fw.broadcast_mul(&self.mask_w)?)?;
|
||||
let cos = blended.cos()?.to_dtype(self.dtype)?;
|
||||
let sin = blended.sin()?.to_dtype(self.dtype)?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
/// Apply rotary to `q`, `k` (shape `(B, H, L, head_dim)`) using
|
||||
/// precomputed `cos`/`sin` of shape `(L, rotary_dim/2)`. Partial
|
||||
/// rotary: only the first `rotary_dim` dims rotate; the tail passes
|
||||
/// through unchanged.
|
||||
pub fn apply_cos_sin(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
cos: &Tensor,
|
||||
sin: &Tensor,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let (_, _, _seq_len, head_dim_in) = q.dims4()?;
|
||||
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||
if self.rotary_dim == self.head_dim {
|
||||
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, cos, sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, cos, sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
} else {
|
||||
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||
let tail = self.head_dim - self.rotary_dim;
|
||||
let q_rot = q
|
||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||
.contiguous()?;
|
||||
let q_pass = q.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||
let k_rot = k
|
||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||
.contiguous()?;
|
||||
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, cos, sin)?;
|
||||
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, cos, sin)?;
|
||||
let q_embed =
|
||||
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||
let k_embed =
|
||||
Tensor::cat(&[&k_rotated, &k_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute interleaved-M-RoPE 3D position ids for a full prompt that may
|
||||
/// contain image-placeholder runs, plus the decode `rope_delta`.
|
||||
///
|
||||
/// Mirrors the reference `get_rope_index`:
|
||||
/// - text tokens advance a single running counter `c`, all three axes
|
||||
/// equal (`[c, c, c]`);
|
||||
/// - each contiguous run of `image_token_id` is one image; its tokens get
|
||||
/// `[base + t, base + h, base + w]` in row-major (t outer, h, w inner),
|
||||
/// where `base` is the counter at the run's start; after the run the
|
||||
/// counter resumes from `base + max(grid_t, grid_h, grid_w)`.
|
||||
///
|
||||
/// Returns `(text_pos, height_pos, width_pos, rope_delta)`, each pos `Vec`
|
||||
/// length `input_ids.len()`. `rope_delta = final_counter - seq_len`: add it
|
||||
/// to a plain decode offset so text resumes from the counter after the
|
||||
/// (position-compressed) image blocks.
|
||||
///
|
||||
/// Whether interleaved M-RoPE for image tokens is enabled. Default
|
||||
/// **on** — Qwen3.6 was trained with interleaved M-RoPE, and this
|
||||
/// implementation matches the HF `apply_interleaved_mrope` /
|
||||
/// `get_rope_index` reference exactly (verified column-for-column). The
|
||||
/// env var is a **kill switch**: `NEURON_MROPE=0` falls back to plain
|
||||
/// sequential positions for image tokens (the pre-M-RoPE behaviour).
|
||||
pub(crate) fn mrope_enabled() -> bool {
|
||||
std::env::var("NEURON_MROPE")
|
||||
.map(|v| {
|
||||
!matches!(
|
||||
v.trim().to_ascii_lowercase().as_str(),
|
||||
"0" | "false" | "no" | "off"
|
||||
)
|
||||
})
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Position ids for the forward path. Gated by [`mrope_enabled`]: when
|
||||
/// off, returns plain sequential identity positions on all three axes
|
||||
/// (`mrope_cos_sin` then reduces exactly to plain RoPE), restoring the
|
||||
/// pre-M-RoPE behaviour without touching the rest of the forward.
|
||||
pub(crate) fn get_rope_index(
|
||||
input_ids: &[u32],
|
||||
image_token_id: u32,
|
||||
grids: &[(usize, usize)],
|
||||
) -> Result<MRopeIndex> {
|
||||
if !mrope_enabled() {
|
||||
let seq: Vec<i64> = (0..input_ids.len() as i64).collect();
|
||||
return Ok((seq.clone(), seq.clone(), seq, 0));
|
||||
}
|
||||
compute_mrope_index(input_ids, image_token_id, grids)
|
||||
}
|
||||
|
||||
/// The real interleaved-M-RoPE position-id computation (always active in
|
||||
/// unit tests; gated behind [`get_rope_index`] at runtime).
|
||||
///
|
||||
/// `grids` carries the post-merge LM grid `(lm_gh, lm_gw)` for each image
|
||||
/// run, in prompt order — a run length alone cannot recover its
|
||||
/// factorisation, so the grids must be passed (#14 dynamic resolution).
|
||||
/// Each image is a still frame (`grid_t = 1`); its tokens get
|
||||
/// `[base, base + hh, base + ww]` row-major and the shared counter
|
||||
/// resumes at `base + max(lm_gh, lm_gw)`. Multi-image is correct because
|
||||
/// the counter threads across images and interleaved text.
|
||||
pub(crate) fn compute_mrope_index(
|
||||
input_ids: &[u32],
|
||||
image_token_id: u32,
|
||||
grids: &[(usize, usize)],
|
||||
) -> Result<MRopeIndex> {
|
||||
let n = input_ids.len();
|
||||
let mut text = Vec::with_capacity(n);
|
||||
let mut height = Vec::with_capacity(n);
|
||||
let mut width = Vec::with_capacity(n);
|
||||
let mut counter: i64 = 0;
|
||||
let mut i = 0;
|
||||
let mut k = 0; // index into `grids`, one per image run
|
||||
while i < n {
|
||||
if input_ids[i] == image_token_id {
|
||||
let start = i;
|
||||
while i < n && input_ids[i] == image_token_id {
|
||||
i += 1;
|
||||
}
|
||||
let run = i - start;
|
||||
let (grid_h, grid_w) = *grids.get(k).ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"get_rope_index: image run #{k} (len {run}) has no matching grid \
|
||||
({} grids supplied)",
|
||||
grids.len()
|
||||
)
|
||||
})?;
|
||||
k += 1;
|
||||
if grid_h * grid_w != run {
|
||||
anyhow::bail!(
|
||||
"get_rope_index: image run #{} length {run} != grid {grid_h}×{grid_w} = {}",
|
||||
k - 1,
|
||||
grid_h * grid_w
|
||||
);
|
||||
}
|
||||
let base = counter;
|
||||
for hh in 0..grid_h {
|
||||
for ww in 0..grid_w {
|
||||
text.push(base); // grid_t = 1 → temporal axis const
|
||||
height.push(base + hh as i64);
|
||||
width.push(base + ww as i64);
|
||||
}
|
||||
}
|
||||
counter = base + grid_h.max(grid_w) as i64;
|
||||
} else {
|
||||
text.push(counter);
|
||||
height.push(counter);
|
||||
width.push(counter);
|
||||
counter += 1;
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
if k != grids.len() {
|
||||
anyhow::bail!(
|
||||
"get_rope_index: prompt has {k} image run(s) but {} grid(s) were supplied",
|
||||
grids.len()
|
||||
);
|
||||
}
|
||||
let delta = counter - n as i64;
|
||||
Ok((text, height, width, delta))
|
||||
}
|
||||
|
||||
/// `(text_pos, height_pos, width_pos, rope_delta)` returned by
|
||||
/// [`get_rope_index`]; the three vectors combine into the `(3, seq)`
|
||||
/// MRoPE position-id tensor.
|
||||
pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
|
||||
|
||||
/// Build the `(3, seq)` position-id tensor consumed by
|
||||
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
||||
///
|
||||
/// Built directly as **f32** (positions are small integers, exact in
|
||||
/// f32 well past any context length): the freqs matmul needs float
|
||||
/// anyway, and this avoids an i64 tensor / i64→f32 cast on the GPU.
|
||||
pub(crate) fn mrope_position_tensor(
|
||||
text: &[i64],
|
||||
height: &[i64],
|
||||
width: &[i64],
|
||||
dev: &Device,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let seq = text.len();
|
||||
let mut flat = Vec::with_capacity(3 * seq);
|
||||
flat.extend(text.iter().map(|&x| x as f32));
|
||||
flat.extend(height.iter().map(|&x| x as f32));
|
||||
flat.extend(width.iter().map(|&x| x as f32));
|
||||
Tensor::from_vec(flat, (3, seq), dev)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::IndexOp;

    /// Placeholder token id used for image runs throughout these tests.
    const IMG: u32 = 99;

    /// A TextConfig stub with Qwen3.6's rope params (head_dim 256,
    /// partial 0.25 → rotary_dim 64 → half 32; section [11,11,10]).
    fn qwen36_cfg() -> TextConfig {
        let raw = serde_json::json!({
            "hidden_size": 5120,
            "num_hidden_layers": 1,
            "num_attention_heads": 64,
            "num_key_value_heads": 8,
            "head_dim": 256,
            "intermediate_size": 1,
            "vocab_size": 10,
            "rms_norm_eps": 1e-6,
            "max_position_embeddings": 64,
            "layer_types": ["full_attention"],
            "rope_parameters": {
                "rope_theta": 10000000.0,
                "partial_rotary_factor": 0.25,
                "mrope_section": [11, 11, 10],
                "mrope_interleaved": true
            }
        });
        serde_json::from_value(raw).expect("cfg")
    }

    /// f32 CPU rotary embedding built from [`qwen36_cfg`].
    fn qwen36_rope() -> RotaryEmbedding {
        RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &Device::Cpu).unwrap()
    }

    /// The embedding's inverse frequencies as a flat `Vec`.
    fn inv_freq_row(rope: &RotaryEmbedding) -> Vec<f32> {
        rope.inv_freq.i(0).unwrap().to_vec1().unwrap()
    }

    #[test]
    fn mrope_masks_partition_the_half_dim() {
        let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true);
        assert_eq!((mt.len(), mh.len(), mw.len()), (32, 32, 32));
        // Each column belongs to exactly one axis.
        for (i, ((t, h), w)) in mt.iter().zip(&mh).zip(&mw).enumerate() {
            let s = t + h + w;
            assert_eq!(s, 1.0, "column {i} covered {s} times");
        }
        for (mask, expected) in [(&mt, 11.0_f32), (&mh, 11.0), (&mw, 10.0)] {
            assert_eq!(mask.iter().sum::<f32>(), expected);
        }
        // Interleave: temporal 0,3,…; height 1,4,…; width 2,5,…
        assert_eq!(mt[0], 1.0);
        assert_eq!(mh[1], 1.0);
        assert_eq!(mw[2], 1.0);
        assert_eq!(mt[3], 1.0);
    }

    /// The load-bearing invariant: with all three position axes equal
    /// (text), `mrope_cos_sin` must agree with `plain_cos_sin` to within
    /// 1e-6, i.e. M-RoPE is a no-op for text and text inference is
    /// unchanged.
    #[test]
    fn mrope_reduces_to_plain_for_equal_axes() {
        let dev = Device::Cpu;
        let rope = qwen36_rope();

        // positions 5,6,7 on all three axes.
        let flat: Vec<i64> = std::iter::repeat_n([5i64, 6, 7], 3).flatten().collect();
        let pos = Tensor::from_vec(flat, (3, 3), &dev).unwrap();

        let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap();
        let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap();

        let max_abs_diff =
            |a: Tensor, b: Tensor| (a - b).unwrap().abs().unwrap().max_all().unwrap();
        let dcos = max_abs_diff(mc, pc);
        let dsin = max_abs_diff(ms, ps);
        assert!(
            dcos.to_scalar::<f32>().unwrap() < 1e-6,
            "cos mismatch {dcos:?}"
        );
        assert!(
            dsin.to_scalar::<f32>().unwrap() < 1e-6,
            "sin mismatch {dsin:?}"
        );
    }

    /// Hand-checked interleave: a width-axis column (index 2) must track
    /// the WIDTH position, while a temporal column (index 0) tracks the
    /// TEXT position, even when the axes differ.
    #[test]
    fn mrope_blends_axes_at_interleave_columns() {
        let dev = Device::Cpu;
        let rope = qwen36_rope();
        let half = rope.inv_freq.dim(1).unwrap();
        let inv = inv_freq_row(&rope);

        // One token: text=10, height=3, width=7 — all distinct.
        let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap();
        let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
        let cos_row: Vec<f32> = cos.i(0).unwrap().to_vec1().unwrap();
        assert_eq!(cos_row.len(), half);

        // Column 0 (temporal) → text pos 10. Column 1 (height) → 3.
        // Column 2 (width) → 7. Column 3 is temporal again → 10.
        let driving_pos = [10.0_f32, 3.0, 7.0, 10.0];
        for (col, p) in driving_pos.into_iter().enumerate() {
            assert!((cos_row[col] - (p * inv[col]).cos()).abs() < 1e-5);
        }
    }

    #[test]
    fn get_rope_index_text_only_is_sequential() {
        let (t, h, w, delta) = compute_mrope_index(&[1, 2, 3, 4], IMG, &[]).unwrap();
        let sequential = vec![0, 1, 2, 3];
        assert_eq!(t, sequential);
        assert_eq!(h, sequential);
        assert_eq!(w, sequential);
        assert_eq!(delta, 0, "no image → delta 0 → plain decode positions");
    }

    #[test]
    fn get_rope_index_text_image_text() {
        // [text, image(2x2 run of 4), text]. image_token = 99, grid (2,2).
        let ids = [1u32, IMG, IMG, IMG, IMG, 2];
        let (t, h, w, delta) = compute_mrope_index(&ids, IMG, &[(2, 2)]).unwrap();
        // token 0: text → 0. image base=1, grid 2x2:
        //   t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2].
        // resume from base + max(2,2) = 3. trailing text → 3.
        assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
        assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
        assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
        // final counter = 4, seq_len = 6 → delta = -2 (the 4 image tokens
        // advanced the counter by only 2).
        assert_eq!(delta, -2);
        // Decode after the prompt (offset = 6) → text position 6 + (-2) = 4.
        assert_eq!(6 + delta, 4);
    }

    #[test]
    fn get_rope_index_nonsquare_single_image() {
        // text + image(2 rows × 3 cols = 6 tokens). grid (2,3).
        let ids = [1u32, IMG, IMG, IMG, IMG, IMG, IMG];
        let (t, h, w, delta) = compute_mrope_index(&ids, IMG, &[(2, 3)]).unwrap();
        // base = 1; row-major h = [0,0,0,1,1,1]+1, w = [0,1,2,0,1,2]+1.
        assert_eq!(t, vec![0, 1, 1, 1, 1, 1, 1]);
        assert_eq!(h, vec![0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(w, vec![0, 1, 2, 3, 1, 2, 3]);
        // resume from base + max(2,3) = 4; seq_len 7, counter 4 → delta -3.
        assert_eq!(delta, 4 - 7);
    }

    #[test]
    fn get_rope_index_two_images_different_grids() {
        // img(2x2)=4, text, img(1x3)=3. grids [(2,2),(1,3)].
        let ids = [IMG, IMG, IMG, IMG, 7, IMG, IMG, IMG];
        let (t, h, w, delta) = compute_mrope_index(&ids, IMG, &[(2, 2), (1, 3)]).unwrap();
        // img1 base=0 → t=0, h=[0,0,1,1], w=[0,1,0,1]; resume max(2,2)=2.
        // text at counter 2. img2 base=3 → t=3, h=[3,3,3], w=[3,4,5];
        // resume 3+max(1,3)=6.
        assert_eq!(t, vec![0, 0, 0, 0, 2, 3, 3, 3]);
        assert_eq!(h, vec![0, 0, 1, 1, 2, 3, 3, 3]);
        assert_eq!(w, vec![0, 1, 0, 1, 2, 3, 4, 5]);
        assert_eq!(delta, 6 - 8);
    }

    #[test]
    fn get_rope_index_on_by_default() {
        // With NEURON_MROPE unset (default ON), the runtime path returns
        // the real interleaved-M-RoPE positions. (NEURON_MROPE=0 would fall
        // back to identity; not asserted here since it depends on env.)
        let ids = [1, IMG, IMG, IMG, IMG, 2];
        let (t, h, w, _delta) = get_rope_index(&ids, IMG, &[(2, 2)]).unwrap();
        assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
        assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
        assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
    }

    #[test]
    fn get_rope_index_grid_mismatches_error() {
        // run length != grid product.
        assert!(compute_mrope_index(&[IMG; 6], IMG, &[(2, 2)]).is_err());
        // too few grids for the number of image runs.
        assert!(compute_mrope_index(&[IMG, IMG, 7, IMG], IMG, &[(1, 2)]).is_err());
        // too many grids.
        assert!(compute_mrope_index(&[IMG, IMG], IMG, &[(1, 2), (1, 1)]).is_err());
    }

    #[test]
    fn position_tensor_round_trips_through_mrope_cos_sin() {
        // get_rope_index → (3,seq) tensor → mrope_cos_sin, and confirm an
        // image token's height column tracks its grid row (not the text
        // counter), i.e. the end-to-end position plumbing is wired right.
        let dev = Device::Cpu;
        let rope = qwen36_rope();
        let ids = [1u32, IMG, IMG, IMG, IMG]; // text + 2x2 image
        let (t, h, w, _d) = compute_mrope_index(&ids, IMG, &[(2, 2)]).unwrap();
        let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap();
        assert_eq!(pos.dims(), &[3, 5]);
        let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
        assert_eq!(cos.dims(), &[5, rope.inv_freq.dim(1).unwrap()]);

        let inv = inv_freq_row(&rope);
        // Last image token (index 4): grid (h=1, w=1) → base 1 → h=2, w=2.
        // Height column (index 1) must track h-position 2, not text.
        let last: Vec<f32> = cos.i(4).unwrap().to_vec1().unwrap();
        assert!((last[1] - (2.0 * inv[1]).cos()).abs() < 1e-5);
    }

    #[test]
    fn get_rope_index_196_is_14x14() {
        // One text token followed by a 196-token image run.
        let ids: Vec<u32> = std::iter::once(1u32)
            .chain(std::iter::repeat_n(IMG, 196))
            .collect();
        let (t, h, w, _delta) = compute_mrope_index(&ids, IMG, &[(14, 14)]).unwrap();
        // image base = 1. Last image token (index 196) is grid (h=13,w=13).
        assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base");
        assert_eq!(h[1], 1, "first image row at base");
        assert_eq!(w[1], 1, "first image col at base");
        assert_eq!(h[196], 1 + 13, "last image row = base + 13");
        assert_eq!(w[196], 1 + 13, "last image col = base + 13");
    }
}
|
||||
@@ -1,299 +0,0 @@
|
||||
//! Cache-state snapshots for prefix KV caching (#11).
|
||||
//!
|
||||
//! A snapshot captures everything `clear_kv_cache` would destroy, at
|
||||
//! one consistent token boundary:
|
||||
//!
|
||||
//! - full-attention layers: the `ConcatKvCache` k/v tensors,
|
||||
//! - linear-attention layers: the GatedDeltaNet `conv_state` +
|
||||
//! `recurrent_state`,
|
||||
//! - the model-level `rope_delta` position counter.
|
||||
//!
|
||||
//! The GatedDeltaNet recurrent state cannot be rewound to an earlier
|
||||
//! token, so a snapshot is only reusable when its entire token
|
||||
//! sequence is an exact prefix of an incoming prompt — matching policy
|
||||
//! lives in `harness/prefix_cache.rs`; this module is just the state
|
||||
//! capture.
|
||||
//!
|
||||
//! ## Copy semantics
|
||||
//!
|
||||
//! Attention k/v snapshots share storage with the live cache:
|
||||
//! `ConcatKvCache::append` never mutates stored tensors in place (it
|
||||
//! `cat`s into fresh allocations), so a shallow `Tensor` clone stays
|
||||
//! valid after the live cache moves on. The GDN states are
|
||||
//! **deep-copied** in both directions (`Tensor::copy`): the CUDA
|
||||
//! delta-rule kernels update the recurrent-state buffer in place, and
|
||||
//! `flatten`/`contiguous` on an already-contiguous tensor is a view —
|
||||
//! a shared-storage snapshot would be corrupted by the next forward.
|
||||
|
||||
use candle_core::Tensor;
|
||||
|
||||
/// Per-layer captured state. Variant kind must match the layer's
|
||||
/// `AttentionKind` on restore.
|
||||
pub enum LayerKvSnapshot {
|
||||
/// `ConcatKvCache` contents. `None` when the cache was empty
|
||||
/// (a zero-token snapshot — valid but useless; the registry never
|
||||
/// stores one).
|
||||
Full(Option<(Tensor, Tensor)>),
|
||||
/// GatedDeltaNet state. Either tensor is `None` before the first
|
||||
/// forward touches it.
|
||||
Linear {
|
||||
conv_state: Option<Tensor>,
|
||||
recurrent_state: Option<Tensor>,
|
||||
},
|
||||
}
|
||||
|
||||
/// One consistent cache snapshot of a `Qwen3_5Model` (or its TP
|
||||
/// mirror `tp_qwen3_5::TpQwen3_5Model`, whose per-rank shard state
|
||||
/// has the same shape) at a token boundary. Fields are `pub(crate)`
|
||||
/// so the TP module can construct/consume the same type; holders
|
||||
/// outside the harness only ever pass it back to `restore_kv_cache`.
|
||||
pub struct KvCacheSnapshot {
|
||||
pub(crate) layers: Vec<LayerKvSnapshot>,
|
||||
pub(crate) rope_delta: i64,
|
||||
}
|
||||
|
||||
impl KvCacheSnapshot {
|
||||
/// Number of layer snapshots held (test/diagnostic helper).
|
||||
pub fn layer_count(&self) -> usize {
|
||||
self.layers.len()
|
||||
}
|
||||
|
||||
/// Total bytes of tensor data held by this snapshot. Used for the
|
||||
/// prefix-cache VRAM budget. Attention k/v shares storage with the
|
||||
/// live cache at capture time, but the live cache is cleared or
|
||||
/// replaced before the next request, so counting the full size is
|
||||
/// the honest steady-state figure.
|
||||
pub fn size_bytes(&self) -> u64 {
|
||||
fn t_bytes(t: &Tensor) -> u64 {
|
||||
(t.elem_count() * t.dtype().size_in_bytes()) as u64
|
||||
}
|
||||
self.layers
|
||||
.iter()
|
||||
.map(|l| match l {
|
||||
LayerKvSnapshot::Full(Some((k, v))) => t_bytes(k) + t_bytes(v),
|
||||
LayerKvSnapshot::Full(None) => 0,
|
||||
LayerKvSnapshot::Linear {
|
||||
conv_state,
|
||||
recurrent_state,
|
||||
} => {
|
||||
conv_state.as_ref().map(t_bytes).unwrap_or(0)
|
||||
+ recurrent_state.as_ref().map(t_bytes).unwrap_or(0)
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::{Qwen3_5Model, RopeParameters, TextConfig};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Tiny two-layer config covering both attention kinds.
|
||||
fn tiny_config() -> TextConfig {
|
||||
TextConfig {
|
||||
vocab_size: 32,
|
||||
hidden_size: 16,
|
||||
intermediate_size: 32,
|
||||
num_hidden_layers: 2,
|
||||
num_attention_heads: 2,
|
||||
num_key_value_heads: 1,
|
||||
head_dim: 8,
|
||||
max_position_embeddings: 64,
|
||||
rope_parameters: RopeParameters {
|
||||
rope_theta: 10000.0,
|
||||
partial_rotary_factor: 0.5,
|
||||
rope_type: None,
|
||||
mrope_section: Vec::new(),
|
||||
mrope_interleaved: false,
|
||||
},
|
||||
rms_norm_eps: 1e-6,
|
||||
tie_word_embeddings: true,
|
||||
attn_output_gate: true,
|
||||
layer_types: vec!["linear_attention".into(), "full_attention".into()],
|
||||
full_attention_interval: Some(4),
|
||||
hidden_act: "silu".into(),
|
||||
linear_num_value_heads: 4,
|
||||
linear_num_key_heads: 2,
|
||||
linear_key_head_dim: 4,
|
||||
linear_value_head_dim: 4,
|
||||
linear_conv_kernel_dim: 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a Qwen3_5Model from random weights written to a temp
|
||||
/// safetensors file — the same `ShardedVarBuilder` path the real
|
||||
/// loader uses.
|
||||
fn tiny_model(cfg: &TextConfig) -> Qwen3_5Model {
|
||||
let dev = Device::Cpu;
|
||||
let randn = |shape: &[usize]| Tensor::randn(0f32, 0.2f32, shape, &dev).unwrap();
|
||||
|
||||
let h = cfg.hidden_size;
|
||||
let inter = cfg.intermediate_size;
|
||||
let key_dim = cfg.linear_key_head_dim * cfg.linear_num_key_heads;
|
||||
let value_dim = cfg.linear_value_head_dim * cfg.linear_num_value_heads;
|
||||
let conv_dim = key_dim * 2 + value_dim;
|
||||
let nv = cfg.linear_num_value_heads;
|
||||
let hd = cfg.head_dim;
|
||||
let q_out = cfg.num_attention_heads * hd * 2;
|
||||
let kv_out = cfg.num_key_value_heads * hd;
|
||||
|
||||
let mut t: HashMap<String, Tensor> = HashMap::new();
|
||||
let p = "model.language_model";
|
||||
t.insert(
|
||||
format!("{p}.embed_tokens.weight"),
|
||||
randn(&[cfg.vocab_size, h]),
|
||||
);
|
||||
t.insert(format!("{p}.norm.weight"), randn(&[h]));
|
||||
for (i, kind) in cfg.layer_types.iter().enumerate() {
|
||||
let lp = format!("{p}.layers.{i}");
|
||||
t.insert(format!("{lp}.input_layernorm.weight"), randn(&[h]));
|
||||
t.insert(format!("{lp}.post_attention_layernorm.weight"), randn(&[h]));
|
||||
t.insert(format!("{lp}.mlp.gate_proj.weight"), randn(&[inter, h]));
|
||||
t.insert(format!("{lp}.mlp.up_proj.weight"), randn(&[inter, h]));
|
||||
t.insert(format!("{lp}.mlp.down_proj.weight"), randn(&[h, inter]));
|
||||
match kind.as_str() {
|
||||
"linear_attention" => {
|
||||
let ap = format!("{lp}.linear_attn");
|
||||
t.insert(format!("{ap}.in_proj_qkv.weight"), randn(&[conv_dim, h]));
|
||||
t.insert(format!("{ap}.in_proj_z.weight"), randn(&[value_dim, h]));
|
||||
t.insert(format!("{ap}.in_proj_b.weight"), randn(&[nv, h]));
|
||||
t.insert(format!("{ap}.in_proj_a.weight"), randn(&[nv, h]));
|
||||
t.insert(format!("{ap}.out_proj.weight"), randn(&[h, value_dim]));
|
||||
t.insert(
|
||||
format!("{ap}.conv1d.weight"),
|
||||
randn(&[conv_dim, 1, cfg.linear_conv_kernel_dim]),
|
||||
);
|
||||
t.insert(format!("{ap}.dt_bias"), randn(&[nv]));
|
||||
t.insert(format!("{ap}.A_log"), randn(&[nv]));
|
||||
t.insert(
|
||||
format!("{ap}.norm.weight"),
|
||||
randn(&[cfg.linear_value_head_dim]),
|
||||
);
|
||||
}
|
||||
"full_attention" => {
|
||||
let ap = format!("{lp}.self_attn");
|
||||
t.insert(format!("{ap}.q_proj.weight"), randn(&[q_out, h]));
|
||||
t.insert(format!("{ap}.k_proj.weight"), randn(&[kv_out, h]));
|
||||
t.insert(format!("{ap}.v_proj.weight"), randn(&[kv_out, h]));
|
||||
t.insert(
|
||||
format!("{ap}.o_proj.weight"),
|
||||
randn(&[h, cfg.num_attention_heads * hd]),
|
||||
);
|
||||
t.insert(format!("{ap}.q_norm.weight"), randn(&[hd]));
|
||||
t.insert(format!("{ap}.k_norm.weight"), randn(&[hd]));
|
||||
}
|
||||
other => panic!("unexpected layer type {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
let dir = tempfile::tempdir().expect("tempdir");
|
||||
let path = dir.path().join("model.safetensors");
|
||||
candle_core::safetensors::save(&t, &path).expect("save safetensors");
|
||||
// SAFETY: mmap of a file this test just wrote and nothing else
|
||||
// mutates — same justification as the real loader.
|
||||
let vb = unsafe {
|
||||
candle_nn::var_builder::ShardedSafeTensors::var_builder(
|
||||
&[path.clone()],
|
||||
DType::F32,
|
||||
&dev,
|
||||
)
|
||||
.expect("build ShardedVarBuilder")
|
||||
};
|
||||
Qwen3_5Model::load(cfg, &vb).expect("load tiny qwen3_5 model")
|
||||
}
|
||||
|
||||
fn forward_tokens(model: &mut Qwen3_5Model, tokens: &[u32], offset: usize) -> Vec<f32> {
|
||||
let input = Tensor::new(tokens, &Device::Cpu)
|
||||
.unwrap()
|
||||
.unsqueeze(0)
|
||||
.unwrap();
|
||||
let hidden = model.forward(&input, offset).unwrap();
|
||||
// Last-position hidden row — what the lm_head would consume.
|
||||
let (_, l, _) = hidden.dims3().unwrap();
|
||||
hidden
|
||||
.narrow(1, l - 1, 1)
|
||||
.unwrap()
|
||||
.flatten_all()
|
||||
.unwrap()
|
||||
.to_vec1()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Largest element-wise absolute difference between `a` and `b`.
///
/// Returns `0.0` for two empty slices. If any element difference is NaN
/// the result is NaN, so a caller's `diff < tol` assertion fails instead
/// of passing vacuously. (`f32::max` alone would discard the NaN and
/// report a misleadingly small diff.)
///
/// # Panics
/// Panics if the slices differ in length.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0f32, |worst, d| {
            if worst.is_nan() || d.is_nan() {
                f32::NAN
            } else {
                worst.max(d)
            }
        })
}
|
||||
|
||||
/// Gold test for #11. Prefill a prefix and snapshot it, disturb the live
/// state with unrelated tokens, restore, then prefill only the suffix.
/// The outcome has to equal a fresh prefill of prefix + suffix.
///
/// One pass covers attention KV, GDN conv/recurrent state, and offset
/// bookkeeping. The disturbance step is what would corrupt a GDN
/// snapshot that shared storage instead of being deep-copied.
#[test]
fn restore_then_suffix_matches_full_prefill() {
    let cfg = tiny_config();
    let mut model = tiny_model(&cfg);

    let prefix: &[u32] = &[1, 2, 3];
    let suffix: &[u32] = &[4, 5, 6];
    let full = [prefix, suffix].concat();

    // Reference: the whole sequence in one prefill.
    model.clear_kv_cache();
    let h_reference = forward_tokens(&mut model, &full, 0);

    // Prefix only, then capture.
    model.clear_kv_cache();
    forward_tokens(&mut model, prefix, 0);
    let snap = model.snapshot_kv_cache().expect("snapshot");
    assert_eq!(snap.layer_count(), 2);
    assert!(snap.size_bytes() > 0);

    // Push the live state beyond the snapshot boundary with a different
    // continuation, the way a later request would.
    forward_tokens(&mut model, &[9, 8], prefix.len());

    model.restore_kv_cache(&snap).expect("restore");
    let h_first = forward_tokens(&mut model, suffix, prefix.len());
    let diff = max_abs_diff(&h_reference, &h_first);
    assert!(diff < 1e-4, "restored-prefix forward diverged: {diff}");

    // The snapshot has to outlive restore + forward cycles (the GDN state
    // is mutated in place, hence the deep copy): a second restore must
    // reproduce the same result.
    model.restore_kv_cache(&snap).expect("second restore");
    let h_second = forward_tokens(&mut model, suffix, prefix.len());
    let diff = max_abs_diff(&h_first, &h_second);
    assert!(diff < 1e-6, "second restore diverged: {diff}");
}
|
||||
|
||||
/// A restore must overwrite the live state completely rather than mix
/// with it: continuing after a restore has to give the same result as
/// continuing after a fresh prefill of the same prefix.
#[test]
fn restore_replaces_live_state() {
    let cfg = tiny_config();
    let mut model = tiny_model(&cfg);

    let prefix: &[u32] = &[7, 7, 2, 5];
    let cont: &[u32] = &[11, 13];
    let boundary = prefix.len();

    // Baseline: fresh prefill followed directly by the continuation.
    model.clear_kv_cache();
    forward_tokens(&mut model, prefix, 0);
    let h_fresh = forward_tokens(&mut model, cont, boundary);

    // Same prefix, but detour through unrelated tokens before restoring.
    model.clear_kv_cache();
    forward_tokens(&mut model, prefix, 0);
    let snap = model.snapshot_kv_cache().expect("snapshot");
    forward_tokens(&mut model, &[3, 1, 4, 1, 5], boundary);
    model.restore_kv_cache(&snap).expect("restore");
    let h_restored = forward_tokens(&mut model, cont, boundary);

    let diff = max_abs_diff(&h_fresh, &h_restored);
    assert!(diff < 1e-5, "restore did not replace live state: {diff}");
}
|
||||
}
|
||||
@@ -1,843 +0,0 @@
|
||||
//! Qwen3.6 vision tower.
|
||||
//!
|
||||
//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the
|
||||
//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention,
|
||||
//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each
|
||||
//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim
|
||||
//! groups, and projects to the LM's 5120-dim hidden via
|
||||
//! `linear_fc1 → GELU → linear_fc2`.
|
||||
//!
|
||||
//! Architecture spec sourced from beast's cached Qwen3.6-27B
|
||||
//! safetensors header (Stage A0, see
|
||||
//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed
|
||||
//! from the live `.safetensors` headers, not inferred.
|
||||
//!
|
||||
//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D
|
||||
//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel
|
||||
//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images
|
||||
//! we get away with a trick: when the temporal patch size is 2 and we
|
||||
//! duplicate the still image along the temporal axis (`T = 2`,
|
||||
//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with
|
||||
//! the *sum* of the two temporal weight slices:
|
||||
//!
|
||||
//! ```text
|
||||
//! output = W_0 · frame_0 + W_1 · frame_1 + bias
|
||||
//! = (W_0 + W_1) · frame + bias (static image)
|
||||
//! ```
|
||||
//!
|
||||
//! So at load we sum-collapse the temporal axis and use a 4D
|
||||
//! `Conv2d` kernel. Video support would have to do the real Conv3d
|
||||
//! (different frames mean the trick fails) — tracked alongside the
|
||||
//! dynamic-resolution work in issue #14.
|
||||
//!
|
||||
//! Forward signature (Stage A — no LM splice yet):
|
||||
//!
|
||||
//! ```text
|
||||
//! fn forward(&self, image: &Tensor) -> Result<Tensor>
|
||||
//! ```
|
||||
//!
|
||||
//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`.
|
||||
//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready
|
||||
//! to splice into the LM's input embeddings at `<|image_pad|>`
|
||||
//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196
|
||||
//! LM tokens of dim 5120.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
|
||||
use candle_nn::var_builder::ShardedVarBuilder;
|
||||
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Returns `true` when the environment variable `name` is set to one of
/// `1`, `true`, `yes` or `on` (surrounding whitespace and ASCII case are
/// ignored). Unset, unreadable, or any other value yields `false`.
fn env_truthy(name: &str) -> bool {
    match std::env::var(name) {
        Ok(raw) => {
            let value = raw.trim().to_ascii_lowercase();
            ["1", "true", "yes", "on"].contains(&value.as_str())
        }
        Err(_) => false,
    }
}
|
||||
|
||||
/// Legacy escape hatch: when set, use the original Stage-A sequential
|
||||
/// `pos_embed` lookup instead of the bilinear grid interpolation.
|
||||
/// Default off (interpolation on) — for A/B comparison only.
|
||||
fn vision_legacy_pos() -> bool {
|
||||
env_truthy("NEURON_VISION_LEGACY_POS")
|
||||
}
|
||||
|
||||
/// Legacy escape hatch: when set, skip the 2D vision rotary in the ViT
|
||||
/// attention (the original Stage-A behaviour). Default off (rotary on)
|
||||
/// — for A/B comparison only.
|
||||
fn vision_legacy_rope() -> bool {
|
||||
env_truthy("NEURON_VISION_LEGACY_ROPE")
|
||||
}
|
||||
|
||||
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
|
||||
/// block of `config.json`. Only the fields we actually need are
|
||||
/// captured; serde tolerates the rest.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct VisionConfig {
|
||||
/// Number of ViT blocks (`depth: 27` for Qwen3.6).
|
||||
pub depth: usize,
|
||||
/// Vision-token dimension throughout the tower (1152 for Qwen3.6).
|
||||
pub hidden_size: usize,
|
||||
/// MLP intermediate dim (4304).
|
||||
pub intermediate_size: usize,
|
||||
/// Attention head count (16). `head_dim = hidden_size / num_heads`.
|
||||
pub num_heads: usize,
|
||||
/// Number of slots in the learned position embedding (2304).
|
||||
/// Caps the maximum image patch count.
|
||||
pub num_position_embeddings: usize,
|
||||
/// Spatial patch edge in pixels (16).
|
||||
pub patch_size: usize,
|
||||
/// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we
|
||||
/// collapse this into a single Conv2d for static-image inference;
|
||||
/// see the module-level Conv3d wrinkle).
|
||||
pub temporal_patch_size: usize,
|
||||
/// Patches grouped per LM token by the merger (2 → 2×2 = 4
|
||||
/// patches per LM token).
|
||||
pub spatial_merge_size: usize,
|
||||
/// Vision input channels (3, RGB).
|
||||
pub in_channels: usize,
|
||||
/// Merger output dim — matches the LM's `hidden_size` (5120 for
|
||||
/// Qwen3.6). The merger projects from vision dim → LM dim.
|
||||
pub out_hidden_size: usize,
|
||||
}
|
||||
|
||||
/// Epsilon handed to every `LayerNorm` this module builds through its
/// `layer_norm` loader.
const LAYER_NORM_EPS: f64 = 1e-6;
/// How many LM tokens the merger produces for each merged group of
/// vision tokens.
const LM_TOKENS_PER_MERGE_GROUP: usize = 1;
|
||||
|
||||
/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual.
|
||||
struct VisionBlock {
|
||||
norm1: LayerNorm,
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
norm2: LayerNorm,
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
impl VisionBlock {
|
||||
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let head_dim = h / cfg.num_heads;
|
||||
let norm1 = layer_norm(vb.pp("norm1"), h)?;
|
||||
let qkv = linear(vb.pp("attn.qkv"), h, 3 * h)?;
|
||||
let proj = linear(vb.pp("attn.proj"), h, h)?;
|
||||
let norm2 = layer_norm(vb.pp("norm2"), h)?;
|
||||
let fc1 = linear(vb.pp("mlp.linear_fc1"), h, cfg.intermediate_size)?;
|
||||
let fc2 = linear(vb.pp("mlp.linear_fc2"), cfg.intermediate_size, h)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
qkv,
|
||||
proj,
|
||||
norm2,
|
||||
fc1,
|
||||
fc2,
|
||||
num_heads: cfg.num_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
/// `x`: `(N, hidden_size)` un-batched. `rotary`: optional
|
||||
/// `(cos, sin)` each `(N, head_dim/2)` — the 2D vision rotary applied
|
||||
/// to q/k. Returns same shape.
|
||||
fn forward(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
|
||||
let attn_in = self.norm1.forward(x)?;
|
||||
let attn_out = self.attention(&attn_in, rotary)?;
|
||||
let x = x.add(&attn_out)?;
|
||||
let mlp_in = self.norm2.forward(&x)?;
|
||||
let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
|
||||
x.add(&mlp_out).map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Multi-head self-attention over the patch sequence. No causal
|
||||
/// mask — every patch attends to every other patch. When `rotary` is
|
||||
/// given, the 2D vision rotary (row/col position) is applied to q, k
|
||||
/// before the scores, matching HF `apply_rotary_pos_emb_vision`
|
||||
/// (`rope_slow` is the same rotate-half form).
|
||||
fn attention(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
|
||||
let (n, hidden) = x.dims2()?;
|
||||
// qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
let qkv = qkv.reshape((n, 3, self.num_heads, self.head_dim))?;
|
||||
// Transpose to (3, num_heads, N, head_dim) for per-head views.
|
||||
let qkv = qkv.permute((1, 2, 0, 3))?.contiguous()?;
|
||||
let q = qkv.i(0)?;
|
||||
let k = qkv.i(1)?;
|
||||
let v = qkv.i(2)?;
|
||||
// 2D vision rotary on q, k (full head_dim; rotate-half form).
|
||||
let (q, k) = match rotary {
|
||||
Some((cos, sin)) => {
|
||||
let q = candle_nn::rotary_emb::rope_slow(&q.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
|
||||
let k = candle_nn::rotary_emb::rope_slow(&k.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
|
||||
(q, k)
|
||||
}
|
||||
None => (q, k),
|
||||
};
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
// (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
|
||||
let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||
let scores = (scores * scale)?;
|
||||
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||
// (num_heads, N, N) @ (num_heads, N, head_dim) -> (num_heads, N, head_dim)
|
||||
let out = probs.matmul(&v)?;
|
||||
// Merge heads back: (N, num_heads, head_dim) -> (N, hidden).
|
||||
let out = out.permute((1, 0, 2))?.contiguous()?.reshape((n, hidden))?;
|
||||
self.proj.forward(&out).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4
|
||||
/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh →
|
||||
/// fc2. Output dim is the LM's hidden_size.
|
||||
struct VisionMerger {
|
||||
norm: LayerNorm,
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
merge_input_dim: usize,
|
||||
spatial_merge_size: usize,
|
||||
}
|
||||
|
||||
impl VisionMerger {
|
||||
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let merge = cfg.spatial_merge_size;
|
||||
let merge_input_dim = h * merge * merge;
|
||||
let norm = layer_norm(vb.pp("norm"), h)?;
|
||||
let fc1 = linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?;
|
||||
let fc2 = linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?;
|
||||
Ok(Self {
|
||||
norm,
|
||||
fc1,
|
||||
fc2,
|
||||
merge_input_dim,
|
||||
spatial_merge_size: merge,
|
||||
})
|
||||
}
|
||||
|
||||
/// `tokens`: `(grid_h, grid_w, hidden_size)`. The merger reshapes
|
||||
/// each `merge×merge` block of adjacent patches into a single
|
||||
/// concatenated vector, then projects.
|
||||
///
|
||||
/// `grid_h` and `grid_w` must both be multiples of
|
||||
/// `spatial_merge_size`. Returns
|
||||
/// `(grid_h/merge × grid_w/merge, out_hidden_size)`.
|
||||
fn forward(&self, tokens: &Tensor) -> Result<Tensor> {
|
||||
let (gh, gw, h) = tokens.dims3()?;
|
||||
let m = self.spatial_merge_size;
|
||||
anyhow::ensure!(
|
||||
gh.is_multiple_of(m) && gw.is_multiple_of(m),
|
||||
"merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})"
|
||||
);
|
||||
let tokens = self.norm.forward(tokens)?;
|
||||
// (gh, gw, h) -> (gh/m, m, gw/m, m, h) -> (gh/m, gw/m, m, m, h)
|
||||
// -> flatten last three -> (gh/m, gw/m, m*m*h) -> (N_lm, merge_input_dim)
|
||||
let out_h = gh / m;
|
||||
let out_w = gw / m;
|
||||
let merged = tokens
|
||||
.reshape((out_h, m, out_w, m, h))?
|
||||
.permute((0, 2, 1, 3, 4))?
|
||||
.contiguous()?
|
||||
.reshape((out_h * out_w, self.merge_input_dim))?;
|
||||
let hidden = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&merged)?)?)?;
|
||||
Ok(hidden)
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-dimensional rotary position embedding used by the vision tower.
///
/// A patch's `head_dim` is rotated according to its `(row, col)` grid
/// coordinate: the first half of the rotary frequencies follows the row
/// index and the second half follows the column index. Mirrors HF
/// `Qwen3VLVisionRotaryEmbedding` + `rot_pos_emb` (θ = 10000,
/// `dim = head_dim/2`).
struct VisionRotaryEmbedding {
    /// Inverse frequencies for one spatial axis: `half = head_dim/4`
    /// f32 values.
    inv_freq: Vec<f32>,
}
|
||||
|
||||
impl VisionRotaryEmbedding {
|
||||
fn new(head_dim: usize) -> Self {
|
||||
// HF: Qwen3VLVisionRotaryEmbedding(head_dim // 2), theta 10000.
|
||||
let dim = head_dim / 2;
|
||||
let theta = 10000f32;
|
||||
let inv_freq = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
Self { inv_freq }
|
||||
}
|
||||
|
||||
/// cos/sin for a `gh×gw` patch grid in **row-major** order. Returns
|
||||
/// `(cos, sin)` each `(gh*gw, head_dim/2)`: per patch, the row-axis
|
||||
/// freqs `row·inv_freq` followed by the col-axis freqs `col·inv_freq`
|
||||
/// (then `rope_slow` duplicates them across the full head_dim).
|
||||
fn cos_sin(
|
||||
&self,
|
||||
gh: usize,
|
||||
gw: usize,
|
||||
dev: &Device,
|
||||
dtype: DType,
|
||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||
let half = self.inv_freq.len();
|
||||
let n = gh * gw;
|
||||
let mut data = Vec::with_capacity(n * 2 * half);
|
||||
for hi in 0..gh {
|
||||
for wi in 0..gw {
|
||||
for &f in &self.inv_freq {
|
||||
data.push(hi as f32 * f);
|
||||
}
|
||||
for &f in &self.inv_freq {
|
||||
data.push(wi as f32 * f);
|
||||
}
|
||||
}
|
||||
}
|
||||
let freqs = Tensor::from_vec(data, (n, 2 * half), dev)?;
|
||||
let cos = freqs.cos()?.to_dtype(dtype)?;
|
||||
let sin = freqs.sin()?.to_dtype(dtype)?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
}
|
||||
|
||||
/// The vision tower itself.
|
||||
pub struct VisionTower {
|
||||
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
|
||||
patch_embed: Conv2d,
|
||||
pos_embed: Embedding,
|
||||
rotary: VisionRotaryEmbedding,
|
||||
blocks: Vec<VisionBlock>,
|
||||
merger: VisionMerger,
|
||||
config: VisionConfig,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl VisionTower {
|
||||
/// Load from a `ShardedVarBuilder` rooted at the safetensors
|
||||
/// `model.visual.` prefix. Caller is responsible for the `pp` —
|
||||
/// see `Qwen3_5ForCausalLM::new` (Stage A4).
|
||||
pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
|
||||
let dtype = vb.dtype();
|
||||
let device = vb.device().clone();
|
||||
|
||||
// patch_embed.proj is published as 5D Conv3d weight; we
|
||||
// sum-collapse the temporal axis (size = temporal_patch_size)
|
||||
// to get a 4D Conv2d kernel. This is exact for the static-
|
||||
// image case where T = temporal_patch_size frames are
|
||||
// identical (i.e. the input was duplicated along T).
|
||||
let raw_weight = vb
|
||||
.pp("patch_embed.proj")
|
||||
.get(
|
||||
(
|
||||
cfg.hidden_size,
|
||||
cfg.in_channels,
|
||||
cfg.temporal_patch_size,
|
||||
cfg.patch_size,
|
||||
cfg.patch_size,
|
||||
),
|
||||
"weight",
|
||||
)
|
||||
.context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?;
|
||||
// Sum along the temporal axis (dim 2) — see module doc-comment.
|
||||
let folded = raw_weight.sum(2)?; // -> (hidden, in_channels, patch, patch)
|
||||
let proj_bias = vb
|
||||
.pp("patch_embed.proj")
|
||||
.get(cfg.hidden_size, "bias")
|
||||
.context("load model.visual.patch_embed.proj.bias")?;
|
||||
let conv_cfg = Conv2dConfig {
|
||||
stride: cfg.patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let patch_embed = Conv2d::new(folded, Some(proj_bias), conv_cfg);
|
||||
|
||||
let pos_embed_weight = vb
|
||||
.pp("pos_embed")
|
||||
.get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
|
||||
.context("load model.visual.pos_embed.weight")?;
|
||||
let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
|
||||
let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
|
||||
|
||||
let blocks_vb = vb.pp("blocks");
|
||||
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||
for i in 0..cfg.depth {
|
||||
blocks.push(
|
||||
VisionBlock::load(&cfg, &blocks_vb.pp(i))
|
||||
.with_context(|| format!("load vision block {i}"))?,
|
||||
);
|
||||
}
|
||||
let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?;
|
||||
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
pos_embed,
|
||||
rotary,
|
||||
blocks,
|
||||
merger,
|
||||
config: cfg,
|
||||
dtype,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &VisionConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Number of LM tokens this tower emits for an `(H, W)` pixel
|
||||
/// image after the merger. Equal to
|
||||
/// `(H / patch_size / spatial_merge_size) * (W / patch_size / spatial_merge_size)`.
|
||||
pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize {
|
||||
let m = self.config.spatial_merge_size;
|
||||
let patch = self.config.patch_size;
|
||||
let gh = (h as usize) / patch / m;
|
||||
let gw = (w as usize) / patch / m;
|
||||
gh * gw * LM_TOKENS_PER_MERGE_GROUP
|
||||
}
|
||||
|
||||
/// Bilinearly interpolate the learned `pos_embed` grid (a
|
||||
/// `num_grid_per_side × num_grid_per_side` table, 48×48 for Qwen3.6)
|
||||
/// onto the actual `gh × gw` patch grid, in **row-major** patch
|
||||
/// order. Port of the HF `fast_pos_embed_interpolate`: for each patch
|
||||
/// at fractional grid coord `(linspace(0, ngrid-1, gh)[hi],
|
||||
/// linspace(0, ngrid-1, gw)[wi])`, blend the 4 surrounding grid
|
||||
/// entries by bilinear weights. Returns `(gh*gw, hidden)` in
|
||||
/// `self.dtype`.
|
||||
fn interpolated_pos_embed(&self, gh: usize, gw: usize) -> Result<Tensor> {
|
||||
let ngrid = (self.config.num_position_embeddings as f64).sqrt().round() as usize;
|
||||
anyhow::ensure!(
|
||||
ngrid * ngrid == self.config.num_position_embeddings,
|
||||
"num_position_embeddings {} is not a perfect square",
|
||||
self.config.num_position_embeddings
|
||||
);
|
||||
// Evenly-spaced fractional indices into the [0, ngrid-1] grid.
|
||||
let lin = |n: usize| -> Vec<f64> {
|
||||
if n <= 1 {
|
||||
vec![0.0]
|
||||
} else {
|
||||
let step = (ngrid - 1) as f64 / (n - 1) as f64;
|
||||
(0..n).map(|i| i as f64 * step).collect()
|
||||
}
|
||||
};
|
||||
let hs = lin(gh);
|
||||
let ws = lin(gw);
|
||||
let n = gh * gw;
|
||||
|
||||
// Four corner index sets + bilinear weight sets, row-major.
|
||||
let mut idx: [Vec<u32>; 4] = [
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
];
|
||||
let mut wts: [Vec<f32>; 4] = [
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
Vec::with_capacity(n),
|
||||
];
|
||||
for &hv in &hs {
|
||||
let hf = hv as usize; // floor (hv >= 0)
|
||||
let hc = (hf + 1).min(ngrid - 1);
|
||||
let dh = (hv - hf as f64) as f32;
|
||||
for &wv in &ws {
|
||||
let wf = wv as usize;
|
||||
let wc = (wf + 1).min(ngrid - 1);
|
||||
let dw = (wv - wf as f64) as f32;
|
||||
idx[0].push((hf * ngrid + wf) as u32);
|
||||
wts[0].push((1.0 - dh) * (1.0 - dw));
|
||||
idx[1].push((hf * ngrid + wc) as u32);
|
||||
wts[1].push((1.0 - dh) * dw);
|
||||
idx[2].push((hc * ngrid + wf) as u32);
|
||||
wts[2].push(dh * (1.0 - dw));
|
||||
idx[3].push((hc * ngrid + wc) as u32);
|
||||
wts[3].push(dh * dw);
|
||||
}
|
||||
}
|
||||
|
||||
// Blend in f32 and cast once at the end — the reference keeps
|
||||
// the bilinear weights f32 against bf16 table rows; rounding
|
||||
// the weights to bf16 first costs a visible slice of fixture
|
||||
// parity (#15).
|
||||
let mut acc: Option<Tensor> = None;
|
||||
for corner in 0..4 {
|
||||
let idx_t = Tensor::from_vec(std::mem::take(&mut idx[corner]), (n,), &self.device)?;
|
||||
let emb = self
|
||||
.pos_embed
|
||||
.forward(&idx_t)?
|
||||
.to_dtype(candle_core::DType::F32)?; // (n, hidden)
|
||||
let wt = Tensor::from_vec(std::mem::take(&mut wts[corner]), (n, 1), &self.device)?;
|
||||
let term = emb.broadcast_mul(&wt)?;
|
||||
acc = Some(match acc {
|
||||
Some(a) => a.add(&term)?,
|
||||
None => term,
|
||||
});
|
||||
}
|
||||
acc.expect("4 corners accumulated")
|
||||
.to_dtype(self.dtype)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Encode one image.
|
||||
///
|
||||
/// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
|
||||
/// already normalised by `preprocess::preprocess`. Both `H` and
|
||||
/// `W` must be multiples of `patch_size * spatial_merge_size`.
|
||||
///
|
||||
/// Returns `(N_lm, out_hidden_size)` — LM-side image tokens
|
||||
/// ready to splice into the language model's input embeddings.
|
||||
pub fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let (c, h, w) = image.dims3()?;
|
||||
anyhow::ensure!(
|
||||
c == self.config.in_channels,
|
||||
"image must have {} channels, got {c}",
|
||||
self.config.in_channels
|
||||
);
|
||||
let patch = self.config.patch_size;
|
||||
anyhow::ensure!(
|
||||
h.is_multiple_of(patch) && w.is_multiple_of(patch),
|
||||
"image dims must be multiples of patch_size={patch}; got ({h}, {w})"
|
||||
);
|
||||
let gh = h / patch;
|
||||
let gw = w / patch;
|
||||
let n_patches = gh * gw;
|
||||
anyhow::ensure!(
|
||||
n_patches <= self.config.num_position_embeddings,
|
||||
"patch count {n_patches} exceeds pos_embed budget {}",
|
||||
self.config.num_position_embeddings
|
||||
);
|
||||
|
||||
// Add batch axis for conv: (1, 3, H, W) → (1, hidden, gh, gw)
|
||||
// → (hidden, gh, gw) → permute to (gh, gw, hidden) → flatten to (N, hidden)
|
||||
let x = image.unsqueeze(0)?.to_dtype(self.dtype)?;
|
||||
let x = self.patch_embed.forward(&x)?;
|
||||
let x = x.squeeze(0)?;
|
||||
let x = x.permute((1, 2, 0))?.contiguous()?;
|
||||
let x = x.reshape((n_patches, self.config.hidden_size))?;
|
||||
|
||||
// Learned absolute position embeddings. The `pos_embed` table is
|
||||
// a `num_position_embeddings = num_grid_per_side²` learned grid
|
||||
// (48×48 for Qwen3.6); for a `gh×gw` patch grid the reference
|
||||
// (`fast_pos_embed_interpolate`) bilinearly interpolates that
|
||||
// grid to `gh×gw`. The legacy path (a naive sequential lookup of
|
||||
// the first `n_patches` rows) mis-maps the grid stride and
|
||||
// scrambles spatial structure — kept only behind
|
||||
// `NEURON_VISION_LEGACY_POS=1` for A/B comparison.
|
||||
let pos = if vision_legacy_pos() {
|
||||
let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
|
||||
self.pos_embed.forward(&positions)?
|
||||
} else {
|
||||
self.interpolated_pos_embed(gh, gw)?
|
||||
};
|
||||
let mut x = x.add(&pos)?;
|
||||
|
||||
// 2D vision rotary (row/col per patch), computed once and applied
|
||||
// in every block's attention. Legacy escape hatch skips it.
|
||||
let rotary = if vision_legacy_rope() {
|
||||
None
|
||||
} else {
|
||||
Some(self.rotary.cos_sin(gh, gw, &self.device, self.dtype)?)
|
||||
};
|
||||
let rotary_ref = rotary.as_ref();
|
||||
|
||||
for (i, block) in self.blocks.iter().enumerate() {
|
||||
x = block
|
||||
.forward(&x, rotary_ref)
|
||||
.with_context(|| format!("vision block {i}"))?;
|
||||
}
|
||||
|
||||
// (n_patches, hidden) → (gh, gw, hidden) for the merger.
|
||||
let x = x.reshape((gh, gw, self.config.hidden_size))?;
|
||||
self.merger.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Manually load a candle_nn LayerNorm from a ShardedVarBuilder.
|
||||
/// candle_nn's `layer_norm` builder takes `crate::VarBuilder`, not
|
||||
/// `ShardedVarBuilder`, so the existing arch modules in this crate
|
||||
/// uniformly do the manual load + struct construction pattern (see
|
||||
/// `full_attn::load_linear_no_bias`). We follow suit here.
|
||||
fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result<LayerNorm> {
|
||||
let weight = vb
|
||||
.get(size, "weight")
|
||||
.with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?;
|
||||
let bias = vb
|
||||
.get(size, "bias")
|
||||
.with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?;
|
||||
Ok(LayerNorm::new(weight, bias, LAYER_NORM_EPS))
|
||||
}
|
||||
|
||||
/// Manually load a candle_nn Linear (with bias) from a
|
||||
/// ShardedVarBuilder. Same rationale as `layer_norm` above.
|
||||
fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
|
||||
let weight = vb
|
||||
.get((out_dim, in_dim), "weight")
|
||||
.with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?;
|
||||
let bias = vb
|
||||
.get(out_dim, "bias")
|
||||
.with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
/// PyTorch's `gelu_pytorch_tanh` approximation — what the Qwen3.6
|
||||
/// vision tower's `hidden_act` specifies. candle's `Tensor::gelu`
|
||||
/// uses the exact erf-based GELU, so we compute the tanh
|
||||
/// approximation explicitly:
|
||||
///
|
||||
/// ```text
|
||||
/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
/// ```
|
||||
fn gelu_tanh(x: &Tensor) -> Result<Tensor> {
|
||||
// sqrt(2 / pi) = 0.7978845608028654
|
||||
const COEFF: f64 = 0.7978845608028654;
|
||||
const KAPPA: f64 = 0.044715;
|
||||
let x3 = x.powf(3.0)?;
|
||||
let inner = (x + (x3 * KAPPA)?)?;
|
||||
let inner = (inner * COEFF)?;
|
||||
let t = inner.tanh()?;
|
||||
let one_plus_t = (t + 1.0)?;
|
||||
let out = (x * 0.5)?;
|
||||
let out = out.broadcast_mul(&one_plus_t)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Miniature `VisionConfig` for CPU tests with random weights.
    ///
    /// Keeps the Qwen3.6 shape relations (a depth-N stack, hidden_size
    /// divisible by num_heads, intermediate_size > hidden_size) while
    /// using dims small enough that the tests run in milliseconds.
    fn tiny_config() -> VisionConfig {
        VisionConfig {
            // Transformer stack.
            depth: 2,
            num_heads: 4,
            hidden_size: 32,
            intermediate_size: 64,
            // Patch / position geometry.
            in_channels: 3,
            patch_size: 4,
            temporal_patch_size: 2,
            num_position_embeddings: 64,
            // Merger.
            spatial_merge_size: 2,
            out_hidden_size: 48,
        }
    }

    /// Assemble a `VisionTower` field by field from random weights.
    ///
    /// Same trick as `linear_attn::tests::forward_smoke_with_tiny_dimensions`:
    /// skip the safetensors-backed `ShardedVarBuilder` path (it can't be
    /// built from in-memory tensors) and fill in the struct directly.
    /// The real `VisionTower::load` is exercised by the cuda-integration
    /// smoke test in Stage A6.
    fn tiny_tower(cfg: &VisionConfig) -> VisionTower {
        let device = Device::Cpu;
        let dtype = DType::F32;

        // Raw tensor factories.
        let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap();
        let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap();
        let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap();

        // LayerNorm with unit scale and zero shift over `dim` features.
        let identity_norm =
            |dim: usize| LayerNorm::new(ones(&[dim]), zeros(&[dim]), LAYER_NORM_EPS);
        // Linear with a random `(out_dim, in_dim)` weight and a zero bias.
        let random_linear = |out_dim: usize, in_dim: usize| {
            Linear::new(randn(&[out_dim, in_dim]), Some(zeros(&[out_dim])))
        };

        let hidden = cfg.hidden_size;
        let head_dim = hidden / cfg.num_heads;

        let patch_kernel = randn(&[hidden, cfg.in_channels, cfg.patch_size, cfg.patch_size]);
        let patch_embed = Conv2d::new(
            patch_kernel,
            Some(zeros(&[hidden])),
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
        );
        let pos_embed = Embedding::new(randn(&[cfg.num_position_embeddings, hidden]), hidden);

        let blocks: Vec<VisionBlock> = (0..cfg.depth)
            .map(|_| VisionBlock {
                norm1: identity_norm(hidden),
                qkv: random_linear(3 * hidden, hidden),
                proj: random_linear(hidden, hidden),
                norm2: identity_norm(hidden),
                fc1: random_linear(cfg.intermediate_size, hidden),
                fc2: random_linear(hidden, cfg.intermediate_size),
                num_heads: cfg.num_heads,
                head_dim,
            })
            .collect();

        let merge_input_dim = hidden * cfg.spatial_merge_size * cfg.spatial_merge_size;
        let merger = VisionMerger {
            norm: identity_norm(hidden),
            fc1: random_linear(merge_input_dim, merge_input_dim),
            fc2: random_linear(cfg.out_hidden_size, merge_input_dim),
            merge_input_dim,
            spatial_merge_size: cfg.spatial_merge_size,
        };

        let rotary = VisionRotaryEmbedding::new(head_dim);
        VisionTower {
            patch_embed,
            pos_embed,
            rotary,
            blocks,
            merger,
            config: cfg.clone(),
            dtype,
            device,
        }
    }

    #[test]
    fn forward_with_random_weights_produces_finite_output() {
        let cfg = tiny_config();
        let tower = tiny_tower(&cfg);

        // 16×16 image at patch_size=4 → 4×4 patches → after 2×2
        // merge → 2×2 = 4 LM tokens of dim out_hidden_size.
        let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap();
        let out = tower.forward(&image).expect("forward");
        let shape = out.dims2().unwrap();
        assert_eq!(shape.0, 4);
        assert_eq!(shape.1, cfg.out_hidden_size);

        // No NaN/Inf
        let flat = out.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert!(
            flat.iter().copied().all(f32::is_finite),
            "forward must produce finite values"
        );
    }

    #[test]
    fn interpolated_pos_embed_reduces_to_sequential_at_native_grid() {
        // When the patch grid equals the pos_embed grid (gh=gw=ngrid),
        // linspace(0,ngrid-1,ngrid) is the integer ladder, so every patch
        // lands exactly on a grid node (dh=dw=0, corner-0 weight 1) and
        // the bilinear result is the raw pos_embed rows in row-major
        // order — i.e. identical to the legacy sequential lookup.
        let cfg = tiny_config();
        let tower = tiny_tower(&cfg);
        let side = (cfg.num_position_embeddings as f64).sqrt() as usize; // 8

        let interpolated = tower.interpolated_pos_embed(side, side).unwrap();
        let row_ids = Tensor::arange(0u32, (side * side) as u32, &Device::Cpu).unwrap();
        let sequential = tower.pos_embed.forward(&row_ids).unwrap();

        let got = interpolated.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        let want = sequential.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert_eq!(got.len(), want.len());
        got.iter().zip(want.iter()).for_each(|(x, y)| {
            assert!((x - y).abs() < 1e-5, "interp {x} vs seq {y}");
        });
    }

    #[test]
    fn vision_rotary_row_col_structure() {
        // head_dim 8 → rotary dim 4 → inv_freq over [0,2] → 2 freqs/axis.
        let rot = VisionRotaryEmbedding::new(8);
        assert_eq!(rot.inv_freq.len(), 2);
        let (cos, sin) = rot.cos_sin(2, 2, &Device::Cpu, DType::F32).unwrap();
        assert_eq!(cos.dims(), &[4, 4]); // 4 patches, head_dim/2 = 4 cols

        let sin_row = |patch: usize| -> Vec<f32> { sin.i(patch).unwrap().to_vec1().unwrap() };

        // Patch (0,0): all freqs 0 → cos 1, sin 0.
        let origin = sin_row(0);
        assert!(origin.iter().all(|&s| s.abs() < 1e-6));

        // Patch index 2 = grid (1,0): row=1 drives the first half, col=0
        // leaves the second half at zero.
        let below_origin = sin_row(2);
        assert!(below_origin[0].abs() > 1e-6, "row half must be non-zero");
        assert!(
            below_origin[2].abs() < 1e-6 && below_origin[3].abs() < 1e-6,
            "col half must be zero"
        );
    }

    #[test]
    fn lm_token_count_matches_grid() {
        let tower = tiny_tower(&tiny_config());
        // 16x16 image → 4x4 patches → 2x2 = 4 LM tokens
        assert_eq!(tower.lm_tokens_for(16, 16), 4);
        // 32x32 image → 8x8 patches → 4x4 = 16 LM tokens
        assert_eq!(tower.lm_tokens_for(32, 32), 16);
    }

    #[test]
    fn rejects_image_with_dims_not_multiple_of_patch() {
        let tower = tiny_tower(&tiny_config());
        // 17 is not a multiple of patch_size = 4.
        let image = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap();
        let rendered = format!("{:#}", tower.forward(&image).unwrap_err());
        assert!(rendered.contains("patch_size"));
    }

    #[test]
    fn rejects_image_with_wrong_channel_count() {
        let tower = tiny_tower(&tiny_config());
        // 4 channels where the config declares in_channels = 3.
        let image = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap();
        let rendered = format!("{:#}", tower.forward(&image).unwrap_err());
        assert!(rendered.contains("channels"));
    }

    #[test]
    fn gelu_tanh_matches_known_values() {
        // Reference values for gelu_pytorch_tanh from PyTorch:
        //   gelu_tanh(0.0)  = 0.0
        //   gelu_tanh(1.0)  ≈ 0.8411920071
        //   gelu_tanh(-1.0) ≈ -0.1588079929
        let inputs = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap();
        let outputs: Vec<f32> = gelu_tanh(&inputs).unwrap().to_vec1().unwrap();
        let (at_zero, at_one, at_minus_one) = (outputs[0], outputs[1], outputs[2]);

        assert!(at_zero.abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", at_zero);
        assert!(
            (at_one - 0.841_192_f32).abs() < 1e-5,
            "gelu_tanh(1) ≈ 0.84119, got {}",
            at_one
        );
        assert!(
            (at_minus_one - -0.158_808_f32).abs() < 1e-5,
            "gelu_tanh(-1) ≈ -0.15881, got {}",
            at_minus_one
        );
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,562 +0,0 @@
|
||||
//! Chat-template rendering for the model-supplied Jinja templates
|
||||
//! HuggingFace tokenizers ship in `tokenizer_config.json`.
|
||||
//!
|
||||
//! ## Background
|
||||
//!
|
||||
//! Every modern open-weight model bundles a `chat_template` field
|
||||
//! in its `tokenizer_config.json` — a Jinja2 template string that
|
||||
//! converts a sequence of `{role, content}` messages into the
|
||||
//! exact prompt the model was trained on. Examples:
|
||||
//!
|
||||
//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…`
|
||||
//! with conditional `enable_thinking` handling that injects an
|
||||
//! empty `<think>\n\n</think>` block when set false.
|
||||
//! - DeepSeek-R1: similar im_start framing with different special-
|
||||
//! token names.
|
||||
//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing.
|
||||
//! - Claude / Llama: another shape again.
|
||||
//!
|
||||
//! Rendering the model's own template is the only way to get the
|
||||
//! *exact* prompt format the model was trained on plus the
|
||||
//! model-specific kwargs (`enable_thinking`, `tools`, …) without
|
||||
//! hardcoding per-model logic. The alternative — neuron's previous
|
||||
//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that
|
||||
//! ignored kwargs entirely.
|
||||
//!
|
||||
//! ## Scope
|
||||
//!
|
||||
//! This module is request-side only: it builds the prompt string
|
||||
//! the tokenizer ingests before inference. The reasoning- and
|
||||
//! tool-call-marker token routing (issues #6, #8) is response-side
|
||||
//! and stays in `wire::openai_chat` / the streaming inference
|
||||
//! loops.
|
||||
//!
|
||||
//! ## Fallback
|
||||
//!
|
||||
//! When none of the template sources (`chat_template.jinja`,
//! `chat_template.json`, or the `chat_template` field of
//! `tokenizer_config.json`) yields a template, or the template
//! fails to render, the caller falls back to
//! `format_qwen3_prompt`. The `NEURON_USE_CHAT_TEMPLATE=false` env
//! var is a global kill switch — if a deploy goes sideways and the
//! renderer is to blame, an operator can flip the env and restart
//! neuron without shipping a new build.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use cortex_core::openai::{ChatMessage, MessageContent};
|
||||
use minijinja::{Environment, Error as MjError, ErrorKind as MjErrorKind, Value as MjValue};
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
|
||||
/// Environment variable that, when set to `false`, `0`, `no` or `off`
/// (compared case-insensitively, surrounding whitespace ignored),
/// forces every model to skip its `chat_template` and fall back
/// to `format_qwen3_prompt`. Default (unset) is "use chat
/// templates where available".
pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE";

/// Read the global kill switch. `true` means chat templates are
/// enabled; `false` forces the fallback path everywhere.
///
/// Templates are disabled only when [`KILL_SWITCH_ENV`] holds one of
/// the recognised "off" spellings listed on that constant. An unset
/// variable, a value that is not valid Unicode, an empty string, or
/// any other text leaves templates enabled.
pub fn chat_templates_enabled() -> bool {
    /// Spellings that disable chat templates.
    const DISABLING_VALUES: [&str; 4] = ["false", "0", "no", "off"];

    // `var` fails for both "unset" and "not valid Unicode"; either way
    // the switch has not been thrown.
    let Ok(raw) = std::env::var(KILL_SWITCH_ENV) else {
        return true;
    };
    let setting = raw.trim();
    !DISABLING_VALUES
        .iter()
        .any(|spelling| setting.eq_ignore_ascii_case(spelling))
}
|
||||
|
||||
/// Probe for the model's chat template in the same directory the
|
||||
/// tokenizer was loaded from, following HuggingFace `transformers`
|
||||
/// precedence: a standalone `chat_template.jinja` (then
|
||||
/// `chat_template.json`) wins over the `chat_template` field in
|
||||
/// `tokenizer_config.json`.
|
||||
///
|
||||
/// This matters for multimodal models: Qwen3-VL / Qwen3.6 ship their
|
||||
/// vision-aware template (the one that emits
|
||||
/// `<|vision_start|><|image_pad|><|vision_end|>` per image) **only** in
|
||||
/// `chat_template.jinja`, and may not ship a `tokenizer_config.json` at
|
||||
/// all. Reading `tokenizer_config.json` alone returned `None`, which
|
||||
/// dropped image content into the text-only `format_qwen3_prompt`
|
||||
/// fallback — so image requests rendered zero `<|image_pad|>` tokens
|
||||
/// and the vision path bailed on the count mismatch.
|
||||
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
|
||||
let parent = tokenizer_json_path.parent()?;
|
||||
|
||||
// 1. Standalone Jinja file — raw template text, highest priority.
|
||||
let jinja_path = parent.join("chat_template.jinja");
|
||||
match std::fs::read_to_string(&jinja_path) {
|
||||
Ok(text) if !text.trim().is_empty() => {
|
||||
tracing::info!(
|
||||
path = %jinja_path.display(),
|
||||
"chat_template: loaded standalone chat_template.jinja"
|
||||
);
|
||||
return Some(text);
|
||||
}
|
||||
Ok(_) => {
|
||||
tracing::warn!(
|
||||
path = %jinja_path.display(),
|
||||
"chat_template: chat_template.jinja present but empty; trying other sources"
|
||||
);
|
||||
}
|
||||
Err(_) => {} // absent — fall through, common case
|
||||
}
|
||||
|
||||
// 2. Standalone JSON file — `{"chat_template": "..."}` form.
|
||||
let json_path = parent.join("chat_template.json");
|
||||
if json_path.exists()
|
||||
&& let Some(t) = load_chat_template_from(&json_path)
|
||||
{
|
||||
tracing::info!(
|
||||
path = %json_path.display(),
|
||||
"chat_template: loaded standalone chat_template.json"
|
||||
);
|
||||
return Some(t);
|
||||
}
|
||||
|
||||
// 3. The `chat_template` field inside tokenizer_config.json.
|
||||
let config_path = parent.join("tokenizer_config.json");
|
||||
load_chat_template_from(&config_path)
|
||||
}
|
||||
|
||||
/// Best-effort load of `chat_template` from a HuggingFace
|
||||
/// `tokenizer_config.json`. Returns `None` when the file is
|
||||
/// absent, doesn't parse, or lacks the `chat_template` field —
|
||||
/// in all of those cases the caller falls back to
|
||||
/// `format_qwen3_prompt`. Warnings are logged so an operator can
|
||||
/// see why the fallback fired.
|
||||
pub fn load_chat_template_from(path: &Path) -> Option<String> {
|
||||
let text = match std::fs::read_to_string(path) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"chat_template: tokenizer_config.json absent or unreadable; falling back"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let value: Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"chat_template: tokenizer_config.json failed to parse; falling back"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
// Some tokenizer_config.json files carry `chat_template` as an
|
||||
// array of `{name, template}` objects (multi-template models —
|
||||
// tool-use variant, default variant). For now we pick the first
|
||||
// entry; future iterations could honour a name hint.
|
||||
match value.get("chat_template") {
|
||||
Some(Value::String(s)) => Some(s.clone()),
|
||||
Some(Value::Array(arr)) => {
|
||||
for entry in arr {
|
||||
if let Some(t) = entry.get("template").and_then(|v| v.as_str()) {
|
||||
return Some(t.to_string());
|
||||
}
|
||||
}
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
"chat_template: array form had no usable template entry; falling back"
|
||||
);
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the chat template into the prompt the model expects.
|
||||
///
|
||||
/// `template` is the raw Jinja string from `tokenizer_config.json`.
|
||||
/// `messages` is the conversation in order. `kwargs` is the
|
||||
/// `chat_template_kwargs` object the client supplied on the
|
||||
/// request (or `Value::Null` when absent). The function expands
|
||||
/// the kwargs into the Jinja context alongside the standard
|
||||
/// `messages` and `add_generation_prompt` variables HF templates
|
||||
/// expect.
|
||||
///
|
||||
/// `tools` is the request's `tools` array (or `Value::Null`).
|
||||
/// Some chat templates iterate it to emit native tool definitions
|
||||
/// (Qwen3-Coder's tool-use template, Mistral's [TOOL_DEFINITIONS]
|
||||
/// frame). We forward whatever the client sent without
|
||||
/// interpretation.
|
||||
pub fn render_chat_template(
|
||||
template: &str,
|
||||
messages: &[ChatMessage],
|
||||
tools: &Value,
|
||||
kwargs: &Value,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
|
||||
// HF chat templates are authored against Python's Jinja2 with its
|
||||
// string semantics. Bridge the two so real model templates render:
|
||||
//
|
||||
// - `pycompat::unknown_method_callback` supplies Python str/list/dict
|
||||
// methods minijinja lacks natively (`startswith`, `endswith`,
|
||||
// `split`, `rstrip`, `lstrip`, …) — the Qwen3.6 template uses
|
||||
// several in its think-block and tool-response handling.
|
||||
// - `raise_exception` is the global HF templates call to reject
|
||||
// malformed inputs (e.g. an image in a system message). Map it to
|
||||
// a render error so the caller falls back / surfaces it.
|
||||
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
|
||||
env.add_function(
|
||||
"raise_exception",
|
||||
|msg: String| -> Result<MjValue, MjError> {
|
||||
Err(MjError::new(MjErrorKind::InvalidOperation, msg))
|
||||
},
|
||||
);
|
||||
|
||||
// Compile the template against a fixed name so error messages
|
||||
// surface "chat_template" rather than `<template>`.
|
||||
env.add_template("chat_template", template)
|
||||
.context("compile chat_template")?;
|
||||
let tmpl = env.get_template("chat_template").unwrap();
|
||||
|
||||
// Convert our internal ChatMessage shape into the
|
||||
// `[{role, content}]` shape HF templates iterate. Text content
|
||||
// becomes a string; Parts becomes an array of content blocks.
|
||||
// The HF templates handle both shapes via `content is string`
|
||||
// checks or content-array iteration.
|
||||
let messages_json: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let content_value = match &m.content {
|
||||
MessageContent::Text(s) => Value::String(s.clone()),
|
||||
MessageContent::Parts(parts) => Value::Array(parts.clone()),
|
||||
};
|
||||
let mut obj = serde_json::Map::new();
|
||||
obj.insert("role".into(), Value::String(m.role.clone()));
|
||||
obj.insert("content".into(), content_value);
|
||||
// Forward extras (e.g. tool_calls on assistant turns,
|
||||
// tool_call_id on tool result turns). HF templates that
|
||||
// need them read e.g. `message.tool_calls`.
|
||||
if let Value::Object(extras) = &m.extra {
|
||||
for (k, v) in extras {
|
||||
obj.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
Value::Object(obj)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build the kwargs context. Add base bindings the template
|
||||
// expects (`messages`, `add_generation_prompt`, `tools`) plus
|
||||
// anything the caller passed in `chat_template_kwargs`. Caller
|
||||
// kwargs override the defaults so `add_generation_prompt: false`
|
||||
// from the request actually wins.
|
||||
let mut ctx_map = serde_json::Map::new();
|
||||
ctx_map.insert("messages".into(), Value::Array(messages_json));
|
||||
ctx_map.insert("add_generation_prompt".into(), Value::Bool(true));
|
||||
if !tools.is_null() {
|
||||
ctx_map.insert("tools".into(), tools.clone());
|
||||
}
|
||||
if let Value::Object(kwargs_obj) = kwargs {
|
||||
for (k, v) in kwargs_obj {
|
||||
ctx_map.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
// `Template::render` takes any Serialize value; serde_json's
|
||||
// `Value` implements it natively, so we pass the assembled
|
||||
// context object directly without going through the
|
||||
// `context!` macro (which expects minijinja-native values).
|
||||
tmpl.render(Value::Object(ctx_map))
|
||||
.context("render chat_template")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
/// Reproduces the Qwen3.6 vision template's image-insertion
|
||||
/// condition against the OpenAI `image_url` content-part shape our
|
||||
/// renderer forwards. Confirms minijinja's `'image_url' in item`
|
||||
/// matches a serde_json object that carries that key — i.e. the
|
||||
/// template *can* emit `<|image_pad|>` for our parts.
|
||||
#[test]
|
||||
fn image_url_part_renders_image_pad() {
|
||||
// Condition copied from doc/vision-qwen3_6-spec.md (lines 8-18
|
||||
// of the real chat_template.jinja).
|
||||
let template = "{%- for message in messages -%}\
|
||||
{%- if message.content is string -%}\
|
||||
{{ message.content }}\
|
||||
{%- else -%}\
|
||||
{%- for item in message.content -%}\
|
||||
{%- if 'image' in item or 'image_url' in item or item.type == 'image' -%}\
|
||||
<|vision_start|><|image_pad|><|vision_end|>\
|
||||
{%- elif item.type == 'text' -%}\
|
||||
{{ item.text }}\
|
||||
{%- endif -%}\
|
||||
{%- endfor -%}\
|
||||
{%- endif -%}\
|
||||
{%- endfor -%}";
|
||||
let messages = vec![ChatMessage {
|
||||
role: "user".into(),
|
||||
content: MessageContent::Parts(vec![
|
||||
json!({"type": "text", "text": "what is this?"}),
|
||||
json!({"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}}),
|
||||
]),
|
||||
extra: Value::Object(Default::default()),
|
||||
}];
|
||||
let out = render_chat_template(template, &messages, &Value::Null, &Value::Null)
|
||||
.expect("render should succeed");
|
||||
assert!(
|
||||
out.contains("<|image_pad|>"),
|
||||
"expected the image_url part to emit <|image_pad|>; rendered: {out:?}"
|
||||
);
|
||||
}
|
||||
|
||||
/// `chat_template.jinja` must win over `tokenizer_config.json`'s
|
||||
/// `chat_template` field — the transformers precedence Qwen3.6
|
||||
/// relies on (its vision template ships only in the `.jinja` file).
|
||||
#[test]
|
||||
fn standalone_jinja_template_takes_precedence() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
"neuron_ct_precedence_{}_{}",
|
||||
std::process::id(),
|
||||
line!()
|
||||
));
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
std::fs::write(dir.join("chat_template.jinja"), "FROM_JINJA").unwrap();
|
||||
std::fs::write(
|
||||
dir.join("tokenizer_config.json"),
|
||||
r#"{"chat_template": "FROM_CONFIG"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
// tokenizer_json_path is the sibling the loader takes a parent of.
|
||||
let got = load_chat_template_alongside(&dir.join("tokenizer.json"));
|
||||
std::fs::remove_dir_all(&dir).ok();
|
||||
assert_eq!(got.as_deref(), Some("FROM_JINJA"));
|
||||
}
|
||||
|
||||
/// With no standalone file, fall back to the tokenizer_config.json
|
||||
/// field — the text-only path stays unchanged.
|
||||
#[test]
|
||||
fn falls_back_to_tokenizer_config_when_no_standalone() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
"neuron_ct_fallback_{}_{}",
|
||||
std::process::id(),
|
||||
line!()
|
||||
));
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
std::fs::write(
|
||||
dir.join("tokenizer_config.json"),
|
||||
r#"{"chat_template": "FROM_CONFIG"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let got = load_chat_template_alongside(&dir.join("tokenizer.json"));
|
||||
std::fs::remove_dir_all(&dir).ok();
|
||||
assert_eq!(got.as_deref(), Some("FROM_CONFIG"));
|
||||
}
|
||||
|
||||
/// The *actual* Qwen3.6-27B `chat_template.jinja` (verbatim from
|
||||
/// beast's HF cache) must render in minijinja and emit exactly one
|
||||
/// `<|image_pad|>` for a text+image user turn. This is the real
|
||||
/// end-to-end check the unit tests above only approximate — it
|
||||
/// catches any minijinja incompatibility (namespace, macros,
|
||||
/// reverse slice, string methods) before it reaches production.
|
||||
#[test]
|
||||
fn real_qwen3_6_template_renders_one_image_pad() {
|
||||
let template = include_str!("testdata/qwen3_6_chat_template.jinja");
|
||||
let messages = vec![ChatMessage {
|
||||
role: "user".into(),
|
||||
content: MessageContent::Parts(vec![
|
||||
json!({"type": "text", "text": "what is this?"}),
|
||||
json!({"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}}),
|
||||
]),
|
||||
extra: Value::Object(Default::default()),
|
||||
}];
|
||||
let out = render_chat_template(template, &messages, &Value::Null, &Value::Null)
|
||||
.expect("real Qwen3.6 template should render in minijinja");
|
||||
let pads = out.matches("<|image_pad|>").count();
|
||||
assert_eq!(
|
||||
pads, 1,
|
||||
"expected exactly one <|image_pad|>; rendered:\n{out}"
|
||||
);
|
||||
assert!(out.contains("<|vision_start|>") && out.contains("<|vision_end|>"));
|
||||
}
|
||||
|
||||
fn user_msg(text: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: MessageContent::Text(text.into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_msg(text: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(text.into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimal Qwen3-style template — enough surface to confirm
|
||||
/// our renderer threads role + content correctly without
|
||||
/// loading a real model's tokenizer_config.json.
|
||||
const QWEN3_LIKE: &str = "{%- for message in messages -%}\
|
||||
<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n\
|
||||
{%- endfor -%}\
|
||||
{%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}";
|
||||
|
||||
#[test]
|
||||
fn renders_basic_conversation() {
|
||||
let prompt = render_chat_template(
|
||||
QWEN3_LIKE,
|
||||
&[user_msg("hello"), assistant_msg("hi"), user_msg("bye")],
|
||||
&Value::Null,
|
||||
&Value::Null,
|
||||
)
|
||||
.unwrap();
|
||||
// Structural assertions — the exact whitespace produced
|
||||
// by a given template is a Jinja-trim concern that varies
|
||||
// per real chat_template. What matters is that every
|
||||
// turn's role + content thread through in order, and that
|
||||
// the generation cue lands at the end.
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>user\nhello<|im_end|>"),
|
||||
"first user turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>assistant\nhi<|im_end|>"),
|
||||
"assistant turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>user\nbye<|im_end|>"),
|
||||
"second user turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.ends_with("<|im_start|>assistant")
|
||||
|| prompt.ends_with("<|im_start|>assistant\n"),
|
||||
"generation cue missing at end: {prompt}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kwargs_are_threaded_into_template_context() {
|
||||
// Replica of Qwen3's enable_thinking branch in
|
||||
// simplified form. When the kwarg is false, the model's
|
||||
// template injects an empty `<think>...</think>` block
|
||||
// before the generation cue — pre-filling the model's
|
||||
// reasoning slot with "no thinking" so the model emits
|
||||
// the answer directly.
|
||||
let template = "{%- if enable_thinking is defined and enable_thinking is false -%}\
|
||||
NO_THINK\
|
||||
{%- else -%}\
|
||||
THINK_OK\
|
||||
{%- endif -%}";
|
||||
let r_disabled = render_chat_template(
|
||||
template,
|
||||
&[],
|
||||
&Value::Null,
|
||||
&json!({ "enable_thinking": false }),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(r_disabled, "NO_THINK");
|
||||
let r_default = render_chat_template(template, &[], &Value::Null, &Value::Null).unwrap();
|
||||
assert_eq!(r_default, "THINK_OK");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_template_field_returns_none() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-missing-field.json");
|
||||
std::fs::write(&tmp, r#"{"some_other_field": 1}"#).unwrap();
|
||||
assert!(load_chat_template_from(&tmp).is_none());
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_template_from_string_field() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-string.json");
|
||||
std::fs::write(
|
||||
&tmp,
|
||||
r#"{"chat_template": "hello {{ messages[0].content }}"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||
assert!(t.contains("messages[0].content"));
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_template_from_array_form() {
|
||||
// Some HF models ship `chat_template` as `[{name, template}, ...]`.
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-array.json");
|
||||
std::fs::write(
|
||||
&tmp,
|
||||
r#"{"chat_template": [{"name": "default", "template": "ARR"}]}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||
assert_eq!(t, "ARR");
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_file_returns_none_quietly() {
|
||||
let absent = std::path::PathBuf::from("/definitely/not/a/real/path.json");
|
||||
assert!(load_chat_template_from(&absent).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_returns_none() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-garbage.json");
|
||||
std::fs::write(&tmp, b"{not valid json").unwrap();
|
||||
assert!(load_chat_template_from(&tmp).is_none());
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
fn kill_switch_recognises_truthy_falsy_values() {
    // Test against the actual env var so callers see the same behaviour
    // as production. Serialise via a mutex — see path_util.rs for the
    // pattern.
    use std::sync::Mutex;

    /// Puts the kill-switch env var back to its pre-test state when
    /// dropped. Because this runs during unwinding too, a failing
    /// assertion below no longer leaks a mutated environment into the
    /// rest of the test process.
    struct RestoreEnv(Option<String>);

    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            // SAFETY: declared after the lock guard, so it is dropped
            // first and the restore still runs while `LOCK` is held.
            // NOTE(review): `LOCK` is local to this test; code elsewhere
            // that reads the environment concurrently is not covered.
            unsafe {
                match self.0.take() {
                    Some(previous) => std::env::set_var(KILL_SWITCH_ENV, previous),
                    None => std::env::remove_var(KILL_SWITCH_ENV),
                }
            }
        }
    }

    static LOCK: Mutex<()> = Mutex::new(());
    let _g = LOCK.lock().unwrap();
    let _restore = RestoreEnv(std::env::var(KILL_SWITCH_ENV).ok());

    // Unset variable: templates are enabled.
    // SAFETY: env mutation happens only while `LOCK` is held (see the
    // review note on `RestoreEnv`).
    unsafe {
        std::env::remove_var(KILL_SWITCH_ENV);
    }
    assert!(chat_templates_enabled());

    // Falsy spellings, including upper-case and whitespace-padded forms.
    for value in ["false", "0", "no", "off", "FALSE", " no "] {
        // SAFETY: as above — serialised by `LOCK`.
        unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
        assert!(!chat_templates_enabled(), "value {value:?} should disable");
    }

    // Truthy spellings and the empty string leave templates enabled.
    for value in ["true", "1", "yes", ""] {
        // SAFETY: as above — serialised by `LOCK`.
        unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
        assert!(chat_templates_enabled(), "value {value:?} should enable");
    }
    // `_restore` drops here (before `_g`) and reinstates the prior value.
}
|
||||
|
||||
#[test]
fn message_extras_thread_through_for_tool_calls() {
    // HF templates read assistant.tool_calls and tool turns'
    // tool_call_id. Confirm our extras flatten into the message object
    // the template iterates.
    let assistant_turn = ChatMessage {
        role: "assistant".into(),
        content: MessageContent::Text(String::new()),
        extra: json!({
            "tool_calls": [{"id": "t1", "function": {"name": "x", "arguments": "{}"}}]
        }),
    };

    let rendered = render_chat_template(
        "{{ messages[0].tool_calls[0].id }}",
        &[assistant_turn],
        &Value::Null,
        &Value::Null,
    )
    .unwrap();

    assert_eq!(rendered, "t1");
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,330 +0,0 @@
|
||||
//! Job variants accepted by the per-device worker thread.
|
||||
//!
|
||||
//! Each variant carries the inputs the synchronous dispatch handler
|
||||
//! needs plus a `tokio::sync::oneshot::Sender` for the reply. The
|
||||
//! async-side `DeviceWorkerHandle` constructs a job, sends it down the
|
||||
//! `std::sync::mpsc` channel, and `await`s the oneshot for the reply.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
/// Opaque id naming a `ModelArch` held in the worker thread's state slab.
///
/// The handle is a plain `u64` wrapper: cheap to copy and `Send + Sync`,
/// so it can be passed between tasks freely. The `Box<ModelArch>` it
/// refers to stays owned by the worker thread for as long as the handle
/// is live. The only way to drop the model is to send
/// `Job::DropArch { handle }`, which makes the `Drop` impl run on the
/// thread that has the CUDA context bound — the invariant the whole
/// refactor exists to guarantee.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ArchHandle(pub u64);
|
||||
|
||||
/// Opaque id naming a `TpLeaderModel` held in the worker thread's state
/// slab.
///
/// Shaped exactly like [`ArchHandle`], but kept as a distinct type so the
/// two slabs live in separate namespaces and cannot be confused. It was
/// introduced in Phase 3; Phase 4 may merge the two slabs once the TP
/// forward path has proven out.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TpHandle(pub u64);
|
||||
|
||||
/// Opaque id naming a prefix-cache snapshot (#11) that is stored on the
/// worker side, next to the model slab.
///
/// A snapshot is scoped to the `ArchHandle` it was captured from:
/// `Job::DropArch` drops every snapshot under its handle. The snapshot's
/// tensors never leave the worker thread — the async side keeps only
/// this id together with the token sequence the snapshot covers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct KvSnapshotId(pub u64);
|
||||
|
||||
/// A single image payload carried by `Job::ForwardLogitsWithImages` and
/// `Job::EncodeImage`.
///
/// `pixels` holds row-major `(c, h, w)` f32 data — the layout
/// `harness::preprocess::preprocess` produces. Because a `Vec<f32>` is
/// rank-1, the shape travels alongside it in `c`, `h` and `w`.
///
/// The type is `Clone` so that the vision-aware dispatch in
/// `chat_completion` can match on `&vision_route` (which borrows the
/// images) and still pass an owned `Vec<ImageInput>` to the worker job.
/// Cloning costs one pixel-buffer memcpy per image. With dynamic
/// resolution (#14) that size varies: `3 × h × w × 4` bytes. The
/// original note quotes "up to ~6.3 MiB at the default 1024²
/// `max_pixels` budget".
/// NOTE(review): `3 × 1024 × 1024 × 4` bytes is 12 MiB — confirm which
/// figure (or which budget) is intended.
///
/// `h` and `w` are the **resized**, factor-aligned dimensions, so the
/// per-image LM grid is `(h/factor, w/factor)`. That grid is derived
/// downstream, for the splice and for the interleaved-M-RoPE position
/// ids.
#[derive(Clone)]
pub struct ImageInput {
    /// Flat row-major `(c, h, w)` pixel data.
    pub pixels: Vec<f32>,
    /// Channel count.
    pub c: usize,
    /// Resized height.
    pub h: usize,
    /// Resized width.
    pub w: usize,
}
|
||||
|
||||
/// One unit of work for the device worker.
|
||||
///
|
||||
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||
/// single-GPU inference primitives: transfer-in a freshly-loaded
|
||||
/// `ModelArch`, drop it, clear its KV cache, and run one forward step
|
||||
/// returning CPU-side logits ready for sampling on the async caller.
|
||||
///
|
||||
/// Sampling stays on the async side intentionally. The worker copies
|
||||
/// logits to CPU (`Vec<f32>`) before reply, so the device-resident
|
||||
/// tensor never escapes the worker thread and the async caller's
|
||||
/// `LogitsProcessor::sample` runs entirely on the CPU candle backend
|
||||
/// — no incidental context binding on a tokio worker thread.
|
||||
pub enum Job {
|
||||
/// Query free / total VRAM on the device. Returns
|
||||
/// `(free_mb, total_mb)`. CPU builds and contexts that failed to
|
||||
/// initialise reply with `(0, 0)` — matches today's
|
||||
/// `device_vram_mb` sentinel so the log field values don't change.
|
||||
QueryVram {
|
||||
reply: oneshot::Sender<Result<(u64, u64)>>,
|
||||
},
|
||||
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||
/// thread. The dispatch handler opens the GGUF file, parses
|
||||
/// metadata, dispatches on `general.architecture`, and inserts
|
||||
/// the resulting `ModelArch` into the slab. Returns the fresh
|
||||
/// `ArchHandle`.
|
||||
LoadGguf {
|
||||
gguf_path: PathBuf,
|
||||
model_id: String,
|
||||
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||
},
|
||||
/// Load a dense safetensors single-GPU model on the worker
|
||||
/// thread. The dispatch handler reads `config.json`, dispatches on
|
||||
/// `model_type`, builds a `VarBuilder` over the mmap'd
|
||||
/// safetensors, and inserts the resulting `ModelArch`.
|
||||
LoadDense {
|
||||
config_path: PathBuf,
|
||||
safetensors_paths: Vec<PathBuf>,
|
||||
model_id: String,
|
||||
reply: oneshot::Sender<Result<ArchHandle>>,
|
||||
},
|
||||
/// Remove the model from the slab and drop it. The `Drop` runs on
|
||||
/// the worker thread so CUDA tensors release their memory on the
|
||||
/// same context that allocated them.
|
||||
DropArch {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Reset the KV cache for this model. Called at the start of every
|
||||
/// chat completion so a new request doesn't attend over the
|
||||
/// previous one's tokens.
|
||||
ClearKv {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Capture the model's live cache state (attention KV + GDN
|
||||
/// recurrent state + position counters) as a prefix snapshot
|
||||
/// (#11). The snapshot stays in the worker's state, keyed by the
|
||||
/// returned id; the reply carries `(id, bytes)` so the async side
|
||||
/// can do budget accounting without touching tensors. Errors on
|
||||
/// archs without snapshot support.
|
||||
SnapshotKv {
|
||||
handle: ArchHandle,
|
||||
reply: oneshot::Sender<Result<(KvSnapshotId, u64)>>,
|
||||
},
|
||||
/// Replace the model's live cache state with a stored snapshot,
|
||||
/// instead of `ClearKv`, so prefill can resume at the snapshot's
|
||||
/// token boundary. The snapshot remains stored (restorable again).
|
||||
RestoreKv {
|
||||
handle: ArchHandle,
|
||||
snapshot: KvSnapshotId,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Drop one stored snapshot (prefix-cache eviction). Idempotent.
|
||||
DropKvSnapshot {
|
||||
handle: ArchHandle,
|
||||
snapshot: KvSnapshotId,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Run one forward step and copy the resulting `[vocab]` logits to
|
||||
/// CPU. The caller takes the returned `Vec<f32>`, wraps it in a
|
||||
/// CPU `Tensor`, and runs `apply_repeat_penalty` + sampling
|
||||
/// without touching the device context. `offset` is the KV-cache
|
||||
/// position before this step (0 for prefill, `prompt_len + i` for
|
||||
/// the i-th decode step).
|
||||
ForwardLogits {
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Run the LM forward with vision splicing in one round-trip.
|
||||
/// Stage B3 of the vision plan.
|
||||
///
|
||||
/// Inputs:
|
||||
/// - `tokens`: prompt-expanded token ids (the caller has already
|
||||
/// replaced each `<|image_pad|>` with N copies per the
|
||||
/// per-image patch count, so `tokens` already contains exactly
|
||||
/// `sum(n_i)` `image_token_id` entries across all images).
|
||||
/// - `offset`: KV-cache position (same contract as `ForwardLogits`).
|
||||
/// - `images`: one entry per image — preprocessed pixels plus the
|
||||
/// `(c, h, w)` shape. Images are encoded on the worker via the
|
||||
/// model's vision tower (`VisionTower::forward`), concatenated
|
||||
/// in order, and spliced into the LM input embeddings at
|
||||
/// `image_token_id` positions.
|
||||
/// - `image_token_id`: the sentinel token (248056 for Qwen3.6).
|
||||
///
|
||||
/// Returns flat CPU `[vocab]` logits, same as `ForwardLogits`.
|
||||
/// Image embeddings stay device-resident — they're never copied
|
||||
/// to CPU. The "tensors don't escape the worker" invariant holds.
|
||||
ForwardLogitsWithImages {
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
images: Vec<ImageInput>,
|
||||
image_token_id: u32,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Encode one image through the model's vision tower. Stage A5 of
|
||||
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
|
||||
///
|
||||
/// `pixels` is the CPU-side preprocessed image tensor in row-major
|
||||
/// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess`
|
||||
/// produces. `c`, `h`, `w` carry the shape since `Vec<f32>` itself
|
||||
/// is rank-1. The handler reconstructs the tensor on the worker's
|
||||
/// device, runs `VisionTower::forward`, copies the resulting
|
||||
/// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat
|
||||
/// `Vec<f32>` (the caller knows the expected shape from
|
||||
/// `VisionTower::lm_tokens_for(h, w) * hidden_size`).
|
||||
///
|
||||
/// Mirrors the `ForwardLogits` "tensors don't escape" invariant —
|
||||
/// device-side image embeddings are dropped at handler return.
|
||||
/// Stage B will introduce a follow-up variant that keeps the
|
||||
/// embeddings device-resident and references them from the next
|
||||
/// `ForwardLogits` call, avoiding the round-trip copy.
|
||||
EncodeImage {
|
||||
handle: ArchHandle,
|
||||
pixels: Vec<f32>,
|
||||
c: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Initialize the leader's NCCL communicator. The worker's
|
||||
/// `NcclState` mints the `Comm` here so its underlying
|
||||
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||
/// every later `Comm::all_reduce` call. Reply is the worker
|
||||
/// response shape used by the subprocess workers (`InitOk` on
|
||||
/// success, `Error` on failure) so the calling
|
||||
/// `WorkerPool::init_nccl` orchestration stays uniform.
|
||||
///
|
||||
/// Available on both cuda and no-cuda builds — the dispatch
|
||||
/// handler calls `NcclState::init` which has a no-cuda stub that
|
||||
/// replies with `cuda_feature_not_enabled`. Keeping the Job
|
||||
/// variant ungated lets `WorkerPool::init_nccl` stay uniform.
|
||||
NcclInit {
|
||||
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||
comm_id_hex: String,
|
||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||
},
|
||||
/// Run NCCL's all_reduce sanity check on the leader's rank 0.
|
||||
/// Same response shape as `NcclInit`; also available on both
|
||||
/// builds via the no-cuda `NcclState::sanity_check` stub.
|
||||
NcclSanity {
|
||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||
},
|
||||
/// Hand a clonable handle to the leader's NCCL `Comm` back to the
|
||||
/// async side, so the TP step watchdog can call `ncclCommAbort` on
|
||||
/// it from a *different* thread to unblock a wedged collective
|
||||
/// (#17 Stage 2). Fetched once at init while the worker thread is
|
||||
/// still responsive — a thread already wedged in a collective can't
|
||||
/// service this job, which is exactly why the handle is cached
|
||||
/// up front. Replies `None` before `NcclInit` has run.
|
||||
#[cfg(feature = "cuda")]
|
||||
GetLeaderComm {
|
||||
reply: oneshot::Sender<Option<crate::harness::tp::nccl_state::SendComm>>,
|
||||
},
|
||||
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
||||
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
||||
/// `TpLeaderModel` against that Comm. The model's embedded
|
||||
/// `Arc<Comm>` clones, `CudaContext`, and all per-rank CUDA
|
||||
/// tensors live on this thread for the model's lifetime.
|
||||
/// Inserts into the TP slab and returns the fresh `TpHandle`.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpLoadShard {
|
||||
model_id: String,
|
||||
config_json: String,
|
||||
safetensors_paths: Vec<PathBuf>,
|
||||
dtype: candle_core::DType,
|
||||
quant: Option<String>,
|
||||
world_size: u32,
|
||||
reply: oneshot::Sender<Result<TpHandle>>,
|
||||
},
|
||||
/// Drop the TP leader model on the worker thread. CUDA tensors
|
||||
/// and `Arc<Comm>` clones held inside the model release on the
|
||||
/// thread that allocated them.
|
||||
#[cfg(feature = "cuda")]
|
||||
DropTp {
|
||||
handle: TpHandle,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Reset the leader's KV cache for a TP model. Mirrors `ClearKv`
|
||||
/// for single-GPU.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpClearKv {
|
||||
handle: TpHandle,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Capture the leader's TP cache state as a prefix snapshot (#11),
|
||||
/// stored worker-side under the pool-minted `snapshot_id` (shared
|
||||
/// with the subprocess ranks, so all ranks key the same snapshot
|
||||
/// identically). Replies with the leader shard's snapshot bytes.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpSnapshotKv {
|
||||
handle: TpHandle,
|
||||
snapshot_id: u64,
|
||||
reply: oneshot::Sender<Result<u64>>,
|
||||
},
|
||||
/// Replace the leader's live TP cache state with a stored
|
||||
/// snapshot. Mirrors `RestoreKv` for single-GPU.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpRestoreKv {
|
||||
handle: TpHandle,
|
||||
snapshot_id: u64,
|
||||
reply: oneshot::Sender<Result<()>>,
|
||||
},
|
||||
/// Drop one stored leader TP snapshot (eviction). Idempotent.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpDropKvSnapshot {
|
||||
handle: TpHandle,
|
||||
snapshot_id: u64,
|
||||
reply: oneshot::Sender<()>,
|
||||
},
|
||||
/// Run one TP forward step on the leader's shard. Returns CPU-
|
||||
/// side logits as a `Vec<f32>` so the async caller can sample
|
||||
/// without holding a device tensor. The caller is also
|
||||
/// responsible for fan-out to subprocess ranks and drain — only
|
||||
/// the leader's forward moves into the worker thread.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpForwardLogits {
|
||||
handle: TpHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Image-bearing leader (rank 0) forward for the single-shot vision
|
||||
/// prefill. The handler preprocesses each `image_data_uris` entry
|
||||
/// (the same deterministic path every rank runs), encodes through
|
||||
/// the leader's replicated tower, splices at `image_token_id`, and
|
||||
/// returns CPU-side `[vocab]` logits. Image tensors never escape the
|
||||
/// worker thread. Caller fans out `GenerateStepWithImages` to the
|
||||
/// subprocess ranks and drains them; only the leader forward moves
|
||||
/// here.
|
||||
#[cfg(feature = "cuda")]
|
||||
TpForwardLogitsWithImages {
|
||||
handle: TpHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
image_token_id: u32,
|
||||
image_data_uris: Vec<String>,
|
||||
chunk_size: usize,
|
||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||
},
|
||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||
/// queued after this in the channel reply `Err` to their oneshot
|
||||
/// senders (the senders are dropped on the worker's exit, which
|
||||
/// the async-side `Receiver::await` maps to `WorkerError::Gone`).
|
||||
Shutdown,
|
||||
}
|
||||
@@ -1,948 +0,0 @@
|
||||
//! Per-device CUDA worker thread.
|
||||
//!
|
||||
//! One dedicated OS thread per CUDA device the leader uses. The thread
|
||||
//! binds the device's `CudaContext` once at startup and owns it for the
|
||||
//! daemon's lifetime; all GPU operations and VRAM queries for that
|
||||
//! device route through a `std::sync::mpsc` channel into this thread.
|
||||
//! Tensors never escape the thread alive — replies cross the channel
|
||||
//! as plain values (`u32` tokens, `(u64, u64)` mb numbers, `()`).
|
||||
//!
|
||||
//! Rationale, in order of weight:
|
||||
//!
|
||||
//! 1. **Context locality.** cudarc binds the CUDA context per OS thread
|
||||
//! via `cuCtxSetCurrent`. With `tokio::task::spawn_blocking`, the
|
||||
//! blocking thread chosen is arbitrary, so the context gets bound
|
||||
//! onto a different thread each time and `device_vram_mb()` from an
|
||||
//! async task binds it again on the *caller's* thread as a side
|
||||
//! effect. Pinning the context to one named thread ends that.
|
||||
//!
|
||||
//! 2. **Drop safety.** `cudarc::driver::CudaContext`, every `CudaSlice`
|
||||
//! inside a `Tensor`, and every `cudarc::nccl::Comm` call `cuMemFree`
|
||||
//! / `cuCtxDestroy` / `ncclCommDestroy` during `Drop`. These must
|
||||
//! run with the right context current. Owning everything in this
|
||||
//! thread's state slab and dropping it via `Job::DropArch` /
|
||||
//! `Job::Shutdown` is the only safe pattern.
|
||||
//!
|
||||
//! 3. **Poisoning blast radius.** When a CUDA driver error (illegal
|
||||
//! address, OOM cascade) makes the context unrecoverable, today the
|
||||
//! spawn_blocking thread carrying that bad state simply returns to
|
||||
//! tokio's pool — invisible. With the per-device thread, the
|
||||
//! poisoned flag lives on the thread itself; subsequent
|
||||
//! `submit()` calls fast-reject at the channel boundary with a
|
||||
//! clear "device worker is poisoned" error before any further CUDA
|
||||
//! work is attempted.
|
||||
//!
|
||||
//! The TP worker subprocesses (`harness/tp/worker.rs`) are already this
|
||||
//! pattern, just out-of-process. The in-process variant uses the same
|
||||
//! discipline for rank 0.
|
||||
//!
|
||||
//! Phase 1 of the refactor exposes only `Job::QueryVram` + `Job::Shutdown`.
|
||||
//! Forward, kv-cache clear, model load, and NCCL bring-up move in later
|
||||
//! phases. See `/home/grenade/.claude/plans/plan-the-per-device-worker-abstract-micali.md`.
|
||||
|
||||
pub mod dispatch;
|
||||
pub mod jobs;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::{self, Sender};
|
||||
use std::thread::JoinHandle;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use jobs::TpHandle;
|
||||
pub use jobs::{ArchHandle, Job, KvSnapshotId};
|
||||
|
||||
/// Errors returned by `DeviceWorkerHandle` submit methods.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum WorkerError {
|
||||
/// The worker's CUDA context was poisoned by an earlier driver
|
||||
/// error. The thread is still alive (dropping it would re-touch
|
||||
/// the broken context); it returns this error for every job
|
||||
/// submitted until the daemon is restarted.
|
||||
#[error(
|
||||
"device worker for device {device_index} is poisoned \
|
||||
(a prior CUDA driver error left the context unrecoverable); \
|
||||
restart the daemon to recover"
|
||||
)]
|
||||
Poisoned { device_index: u32 },
|
||||
/// The worker thread has exited (`Job::Shutdown` was processed or
|
||||
/// the thread panicked). Subsequent `submit()` calls fail here
|
||||
/// rather than blocking forever.
|
||||
#[error("device worker for device {device_index} is no longer running")]
|
||||
Gone { device_index: u32 },
|
||||
/// The dispatched job returned an `Err`. Forwarded verbatim.
|
||||
#[error(transparent)]
|
||||
Job(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Shared handle to a per-device CUDA worker thread.
|
||||
///
|
||||
/// Cloning the `Arc` lets multiple `LoadedModel`s (and `TpLoadedModel`s)
|
||||
/// share the same worker — there's one worker per CUDA device index,
|
||||
/// not one per model.
|
||||
pub struct DeviceWorkerHandle {
|
||||
device_index: u32,
|
||||
tx: Sender<Job>,
|
||||
poisoned: Arc<AtomicBool>,
|
||||
/// `Mutex<Option<JoinHandle>>` so `shutdown()` can take the handle
|
||||
/// out without `&mut self` and so the inevitable `Drop` after
|
||||
/// `shutdown()` doesn't double-join. The mutex is uncontended in
|
||||
/// practice: only one caller ever takes the handle.
|
||||
join: std::sync::Mutex<Option<JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
impl DeviceWorkerHandle {
|
||||
/// Spawn a new worker for the given CUDA device index.
|
||||
///
|
||||
/// The thread is named `cuda-dev-N` so it shows up legibly in
|
||||
/// `top -H`, `pidstat -t`, and gdb backtraces. On CUDA builds, the
|
||||
/// thread binds `CudaContext::new(N)` on startup; on CPU builds
|
||||
/// (`--no-default-features`) the thread runs without a context and
|
||||
/// every job that touches CUDA falls through to a zero return.
|
||||
pub fn spawn(device_index: u32) -> anyhow::Result<Arc<Self>> {
|
||||
let (tx, rx) = mpsc::channel::<Job>();
|
||||
let poisoned = Arc::new(AtomicBool::new(false));
|
||||
let poisoned_for_thread = Arc::clone(&poisoned);
|
||||
let join = std::thread::Builder::new()
|
||||
.name(format!("cuda-dev-{device_index}"))
|
||||
.spawn(move || {
|
||||
dispatch::run(device_index, rx, poisoned_for_thread);
|
||||
})?;
|
||||
Ok(Arc::new(Self {
|
||||
device_index,
|
||||
tx,
|
||||
poisoned,
|
||||
join: std::sync::Mutex::new(Some(join)),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn device_index(&self) -> u32 {
|
||||
self.device_index
|
||||
}
|
||||
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.poisoned.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Mark the worker's context as poisoned. Future `submit()` calls
|
||||
/// short-circuit to `WorkerError::Poisoned` before sending. The
|
||||
/// dispatch loop also flips into drain-only mode when it sees this
|
||||
/// flag, so any jobs already in flight on the channel reply with
|
||||
/// the same error without touching CUDA.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn set_poisoned(&self) {
|
||||
self.poisoned.store(true, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Send `Job::QueryVram`, await the worker's reply.
|
||||
///
|
||||
/// Returns `Ok((free_mb, total_mb))` on success, `Ok((0, 0))` on
|
||||
/// CPU builds or when the device lacks a bound context, or an
|
||||
/// error if the worker is poisoned, gone, or the query itself
|
||||
/// failed inside cudarc.
|
||||
pub async fn query_vram(&self) -> Result<(u64, u64), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::QueryVram { reply: reply_tx })
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Ask the worker for a clonable handle to the leader's NCCL `Comm`
/// (#17 Stage 2).
///
/// The TP step watchdog caches the result at init so it can later call
/// `ncclCommAbort` from the async thread to unblock a wedged
/// collective.
///
/// Returns `None` when the comm is uninitialised, or when the worker is
/// poisoned or gone. The caller treats a missing handle as
/// "can't abort" and logs it.
#[cfg(feature = "cuda")]
pub async fn get_leader_comm(&self) -> Option<crate::harness::tp::nccl_state::SendComm> {
    if self.is_poisoned() {
        return None;
    }

    let (reply, pending) = oneshot::channel();
    // A send failure means the worker thread has exited.
    self.tx.send(Job::GetLeaderComm { reply }).ok()?;

    match pending.await {
        Ok(comm) => comm,
        Err(_) => None,
    }
}
|
||||
|
||||
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||
/// thread. The hf-hub resolution happens on the async caller; the
|
||||
/// resolved local `gguf_path` plus the spec's model_id are sent
|
||||
/// into the worker which opens, parses, and constructs the
|
||||
/// `ModelArch` on the right thread.
|
||||
pub async fn load_gguf(
|
||||
&self,
|
||||
gguf_path: std::path::PathBuf,
|
||||
model_id: String,
|
||||
) -> Result<ArchHandle, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::LoadGguf {
|
||||
gguf_path,
|
||||
model_id,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load a dense safetensors single-GPU model on the worker thread.
|
||||
pub async fn load_dense(
|
||||
&self,
|
||||
config_path: std::path::PathBuf,
|
||||
safetensors_paths: Vec<std::path::PathBuf>,
|
||||
model_id: String,
|
||||
) -> Result<ArchHandle, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::LoadDense {
|
||||
config_path,
|
||||
safetensors_paths,
|
||||
model_id,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell the worker to drop the `ModelArch` for `handle` on the
|
||||
/// worker thread (so CUDA tensors release on the right context).
|
||||
/// Returns `Ok(())` even if the handle wasn't in the slab — Drop
|
||||
/// is idempotent. Reports `Gone` if the worker isn't running.
|
||||
pub async fn drop_arch(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||
// Poisoning doesn't block DropArch — even on a poisoned
|
||||
// context we want callers to unblock and proceed with the
|
||||
// unload bookkeeping. The dispatch handler under poison just
|
||||
// replies `()` without touching the model (the actual Drop
|
||||
// happens via mem::forget at thread exit per the poison
|
||||
// protocol).
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::DropArch {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the KV cache for the model at `handle`. Called at the
|
||||
/// start of every chat completion so the new prompt doesn't
|
||||
/// attend over the previous request's tokens.
|
||||
pub async fn clear_kv_cache(&self, handle: ArchHandle) -> Result<(), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ClearKv {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Capture the model's live cache state as a worker-side prefix
|
||||
/// snapshot (#11). Returns the snapshot id plus its byte size for
|
||||
/// the async-side budget accounting. Tensors stay on the worker.
|
||||
pub async fn snapshot_kv(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
) -> Result<(jobs::KvSnapshotId, u64), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::SnapshotKv {
|
||||
handle,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the model's live cache state with a stored snapshot —
|
||||
/// called instead of [`Self::clear_kv_cache`] on a prefix-cache
|
||||
/// hit. The snapshot remains stored and restorable again.
|
||||
pub async fn restore_kv(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
snapshot: jobs::KvSnapshotId,
|
||||
) -> Result<(), WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::RestoreKv {
|
||||
handle,
|
||||
snapshot,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop one stored prefix snapshot (eviction). Mirrors
|
||||
/// [`Self::drop_arch`]'s poison-tolerant unit-reply shape so
|
||||
/// bookkeeping always unblocks.
|
||||
pub async fn drop_kv_snapshot(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
snapshot: jobs::KvSnapshotId,
|
||||
) -> Result<(), WorkerError> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::DropKvSnapshot {
|
||||
handle,
|
||||
snapshot,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run one forward step and return the resulting `[vocab]` logits
|
||||
/// as a CPU-side `Vec<f32>`. The caller then samples on a CPU
|
||||
/// candle Tensor without ever binding the device context on its
|
||||
/// tokio thread.
|
||||
pub async fn forward_logits(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ForwardLogits {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward with image-aware splicing in one round-trip. Stage B3.
|
||||
///
|
||||
/// Encodes each image on the worker thread (device-resident), then
|
||||
/// runs the LM forward with the embeddings spliced at
|
||||
/// `image_token_id` positions. Returns CPU `[vocab]` logits, same
|
||||
/// shape as `forward_logits`. Image embeddings never copy back to
|
||||
/// CPU.
|
||||
pub async fn forward_logits_with_images(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
tokens: Vec<u32>,
|
||||
offset: usize,
|
||||
images: Vec<crate::harness::device_worker::jobs::ImageInput>,
|
||||
image_token_id: u32,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::ForwardLogitsWithImages {
|
||||
handle,
|
||||
tokens,
|
||||
offset,
|
||||
images,
|
||||
image_token_id,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a preprocessed image through the model's vision tower
|
||||
/// and return the resulting LM-side image embeddings as a
|
||||
/// flattened CPU `Vec<f32>`. Stage A5.
|
||||
///
|
||||
/// `pixels` is the row-major `(c, h, w)` f32 image —
|
||||
/// `harness::preprocess::preprocess` produces this exact shape.
|
||||
/// The caller knows the expected output length from
|
||||
/// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes
|
||||
/// accordingly.
|
||||
pub async fn encode_image(
|
||||
&self,
|
||||
handle: ArchHandle,
|
||||
pixels: Vec<f32>,
|
||||
c: usize,
|
||||
h: usize,
|
||||
w: usize,
|
||||
) -> Result<Vec<f32>, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::EncodeImage {
|
||||
handle,
|
||||
pixels,
|
||||
c,
|
||||
h,
|
||||
w,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
match reply_rx.await {
|
||||
Ok(result) => result.map_err(WorkerError::from),
|
||||
Err(_) => Err(WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialise the leader's NCCL communicator. The reply uses
|
||||
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||
/// subprocess responses uniformly. Available on no-cuda builds
|
||||
/// too — the dispatch handler calls the no-cuda `NcclState::init`
|
||||
/// stub which replies `cuda_feature_not_enabled`.
|
||||
pub async fn nccl_init(
|
||||
&self,
|
||||
cfg: crate::harness::tp::worker::WorkerConfig,
|
||||
comm_id_hex: String,
|
||||
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::NcclInit {
|
||||
cfg,
|
||||
comm_id_hex,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run an NCCL sanity all_reduce on the leader's rank 0.
|
||||
/// Available on no-cuda builds; replies with an error response.
|
||||
pub async fn nccl_sanity(
|
||||
&self,
|
||||
) -> Result<crate::harness::tp::rpc::WorkerResponse, WorkerError> {
|
||||
if self.poisoned.load(Ordering::Acquire) {
|
||||
return Err(WorkerError::Poisoned {
|
||||
device_index: self.device_index,
|
||||
});
|
||||
}
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.tx
|
||||
.send(Job::NcclSanity { reply: reply_tx })
|
||||
.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})?;
|
||||
reply_rx.await.map_err(|_| WorkerError::Gone {
|
||||
device_index: self.device_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load the leader's TP shard on the worker thread.
///
/// The dispatch handler reads its own `NcclState`'s `Arc<Comm>` directly
/// and builds the `TpLeaderModel` against it, so no `Comm` crosses a
/// thread boundary. Phase 4 replaces the Phase 3 Clone/TransferIn bridge
/// with this single Job.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poison flag is set; nothing is sent.
/// * [`WorkerError::Gone`] when the job cannot be sent or the reply
///   channel closes without an answer.
/// * Whatever error the worker replies with, converted via
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn tp_load_shard(
    &self,
    model_id: String,
    config_json: String,
    safetensors_paths: Vec<std::path::PathBuf>,
    dtype: candle_core::DType,
    quant: Option<String>,
    world_size: u32,
) -> Result<TpHandle, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (respond, response) = oneshot::channel();
    let job = Job::TpLoadShard {
        model_id,
        config_json,
        safetensors_paths,
        dtype,
        quant,
        world_size,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: reply channel closed. Inner error: the job itself failed.
    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
|
||||
|
||||
/// Drop the TP model stored at `handle` on the worker thread.
///
/// The poison flag is not consulted, and the worker answers with a bare
/// `()`.
///
/// # Errors
///
/// [`WorkerError::Gone`] when the job cannot be sent or the reply channel
/// closes without an answer.
#[cfg(feature = "cuda")]
pub async fn drop_tp(&self, handle: TpHandle) -> Result<(), WorkerError> {
    let device_index = self.device_index;

    let (respond, response) = oneshot::channel();
    let job = Job::DropTp {
        handle,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })
}
|
||||
|
||||
/// Reset the leader's KV cache for a TP model.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poison flag is set; nothing is sent.
/// * [`WorkerError::Gone`] when the job cannot be sent or the reply
///   channel closes without an answer.
/// * Whatever error the worker replies with, converted via
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
pub async fn tp_clear_kv(&self, handle: TpHandle) -> Result<(), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (respond, response) = oneshot::channel();
    let job = Job::TpClearKv {
        handle,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: reply channel closed. Inner error: the job itself failed.
    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
|
||||
|
||||
/// Capture the leader's TP cache state as a prefix snapshot (#11), stored
/// under the pool-minted `snapshot_id`.
///
/// On success the value is the leader shard's snapshot bytes.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poison flag is set; nothing is sent.
/// * [`WorkerError::Gone`] when the job cannot be sent or the reply
///   channel closes without an answer.
/// * Whatever error the worker replies with, converted via
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
pub async fn tp_snapshot_kv(
    &self,
    handle: TpHandle,
    snapshot_id: u64,
) -> Result<u64, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (respond, response) = oneshot::channel();
    let job = Job::TpSnapshotKv {
        handle,
        snapshot_id,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: reply channel closed. Inner error: the job itself failed.
    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
|
||||
|
||||
/// Swap the leader's live TP cache state for a stored snapshot.
///
/// Used on a prefix-cache hit in place of [`Self::tp_clear_kv`].
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poison flag is set; nothing is sent.
/// * [`WorkerError::Gone`] when the job cannot be sent or the reply
///   channel closes without an answer.
/// * Whatever error the worker replies with, converted via
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
pub async fn tp_restore_kv(
    &self,
    handle: TpHandle,
    snapshot_id: u64,
) -> Result<(), WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (respond, response) = oneshot::channel();
    let job = Job::TpRestoreKv {
        handle,
        snapshot_id,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: reply channel closed. Inner error: the job itself failed.
    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
|
||||
|
||||
/// Evict a single stored leader TP snapshot.
///
/// Poison-tolerant with a bare `()` reply, the same shape as
/// [`Self::drop_kv_snapshot`].
///
/// # Errors
///
/// [`WorkerError::Gone`] when the job cannot be sent or the reply channel
/// closes without an answer.
#[cfg(feature = "cuda")]
pub async fn tp_drop_kv_snapshot(
    &self,
    handle: TpHandle,
    snapshot_id: u64,
) -> Result<(), WorkerError> {
    let device_index = self.device_index;

    let (respond, response) = oneshot::channel();
    let job = Job::TpDropKvSnapshot {
        handle,
        snapshot_id,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })
}
|
||||
|
||||
/// Execute one TP forward step on the leader's shard and return CPU-side
/// logits as a `Vec<f32>` ready for sampling.
///
/// Fan-out to, and draining of, the subprocess workers is the caller's
/// job and must run concurrently with this call.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poison flag is set; nothing is sent.
/// * [`WorkerError::Gone`] when the job cannot be sent or the reply
///   channel closes without an answer.
/// * Whatever error the worker replies with, converted via
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
pub async fn tp_forward_logits(
    &self,
    handle: TpHandle,
    tokens: Vec<u32>,
    offset: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (respond, response) = oneshot::channel();
    let job = Job::TpForwardLogits {
        handle,
        tokens,
        offset,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: reply channel closed. Inner error: the job itself failed.
    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
|
||||
|
||||
/// Image-bearing TP leader forward (single-shot vision prefill).
///
/// Queues `Job::TpForwardLogitsWithImages` on the worker thread, where
/// the handler preprocesses, encodes, splices and forwards, then returns
/// CPU-side `[vocab]` logits. The `WorkerPool` fans the matching
/// `GenerateStepWithImages` out to subprocess ranks so the row-parallel
/// collectives complete.
///
/// # Errors
///
/// * [`WorkerError::Poisoned`] when the poison flag is set; nothing is sent.
/// * [`WorkerError::Gone`] when the job cannot be sent or the reply
///   channel closes without an answer.
/// * Whatever error the worker replies with, converted via
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn tp_forward_logits_with_images(
    &self,
    handle: TpHandle,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
    chunk_size: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (respond, response) = oneshot::channel();
    let job = Job::TpForwardLogitsWithImages {
        handle,
        tokens,
        offset,
        image_token_id,
        image_data_uris,
        chunk_size,
        reply: respond,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: reply channel closed. Inner error: the job itself failed.
    response
        .await
        .map_err(|_| WorkerError::Gone { device_index })?
        .map_err(WorkerError::from)
}
|
||||
|
||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||
/// twice is a no-op the second time.
|
||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||
// Best-effort send: if the channel is already closed (thread
|
||||
// exited after a prior shutdown or panic) the send fails and
|
||||
// we fall through to the join which returns the panic, if any.
|
||||
let _ = self.tx.send(Job::Shutdown);
|
||||
let join = self.join.lock().unwrap().take();
|
||||
if let Some(j) = join {
|
||||
j.join()
|
||||
.map_err(|_| anyhow::anyhow!("worker thread panicked during shutdown"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DeviceWorkerHandle {
|
||||
fn drop(&mut self) {
|
||||
// Best-effort: send Shutdown so the thread breaks its loop
|
||||
// and exits. We do NOT join here — Drop may run on a tokio
|
||||
// worker thread, and joining a thread that's still processing
|
||||
// the last job would block the runtime. The OS reaps the
|
||||
// thread on detach.
|
||||
let _ = self.tx.send(Job::Shutdown);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Spawn a worker, run one VRAM query, shut it down.
    #[tokio::test]
    async fn spawn_query_vram_shutdown() {
        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        // The CPU build (the only one CI runs) returns (0, 0) by design; a
        // CUDA build with a real device would return real values.
        let vram = worker.query_vram().await.expect("query ok");
        // No value assertion: both components are only read, since the
        // field width matters more than the value.
        let _ = vram.0;
        let _ = vram.1;
        worker.shutdown().expect("shutdown ok");
    }

    /// Smoke test for a worker spawned at a non-zero device index.
    #[tokio::test]
    async fn thread_is_named_correctly() {
        // The thread name is what lets `top -H` / pidstat / gdb show
        // `cuda-dev-N` rather than an opaque tokio worker name. This test
        // does not inspect the name; it only confirms nothing crashes.
        let worker = DeviceWorkerHandle::spawn(7).expect("spawn ok");
        // One job round-trip shows the thread is alive and processing.
        worker.query_vram().await.expect("query ok");
        worker.shutdown().expect("shutdown ok");
    }

    /// After shutdown, submitting a job must yield `Gone`, not block.
    #[tokio::test]
    async fn submit_after_shutdown_returns_gone() {
        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        worker.shutdown().expect("shutdown ok");
        // The channel is closed at this point.
        match worker.query_vram().await {
            Err(WorkerError::Gone { device_index: 0 }) => {}
            other => panic!("expected Gone, got {other:?}"),
        }
    }

    /// A poisoned worker rejects submissions with `Poisoned`.
    #[tokio::test]
    async fn poisoned_flag_short_circuits_submit() {
        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        worker.set_poisoned();
        match worker.query_vram().await {
            Err(WorkerError::Poisoned { device_index: 0 }) => {}
            other => panic!("expected Poisoned, got {other:?}"),
        }
        // Poisoning leaves the channel alive, so shutdown still succeeds.
        worker.shutdown().expect("shutdown ok");
    }

    /// Stage A5: the EncodeImage job round-trips through the worker
    /// channel. No real model is loaded in the slab here, so the dispatch
    /// handler answers with the "no model for handle" error. That is the
    /// expected outcome: it shows the message was routed through the
    /// channel and reached the handler. Real-weights validation lives in
    /// the Stage A7 / Stage B post-deploy smoke on beast.
    #[tokio::test]
    async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() {
        use crate::harness::device_worker::jobs::ArchHandle;

        let worker = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        let unknown_arch = ArchHandle(99_999);
        // Minimal (3, 4, 4) fake image; it gets reconstructed on the
        // worker before the handler fails the unknown ArchHandle lookup.
        let fake_pixels = vec![0.0_f32; 3 * 4 * 4];
        let outcome = worker
            .encode_image(unknown_arch, fake_pixels, 3, 4, 4)
            .await;
        match outcome {
            Err(WorkerError::Job(e)) => {
                let msg = format!("{e:#}");
                assert!(
                    msg.contains("EncodeImage: no model for handle"),
                    "expected unknown-handle error, got: {msg}"
                );
            }
            other => panic!("expected Job(Err), got {other:?}"),
        }
        worker.shutdown().expect("shutdown ok");
    }

    /// Concurrent queries racing a shutdown all finish without panicking.
    #[tokio::test]
    async fn shutdown_drains_pending_jobs() {
        let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
        // Launch many concurrent jobs; each should complete even though a
        // Shutdown is racing them.
        let tasks: Vec<_> = (0..16)
            .map(|_| {
                let h = Arc::clone(&handle);
                tokio::spawn(async move { h.query_vram().await })
            })
            .collect();
        // Brief pause so the senders can actually send before shutdown is
        // issued. Not strictly required because the channel is FIFO, but
        // it makes the intent of the test clearer.
        tokio::time::sleep(Duration::from_millis(10)).await;
        handle.shutdown().expect("shutdown ok");
        for task in tasks {
            // Every query must have completed (Ok or Gone, never panic).
            let _ = task.await.expect("task did not panic");
        }
    }
}
|
||||
1
crates/neuron/src/harness/llamacpp.rs
Normal file
1
crates/neuron/src/harness/llamacpp.rs
Normal file
@@ -0,0 +1 @@
|
||||
// llama.cpp harness implementation — Phase 11.
|
||||
163
crates/neuron/src/harness/mistralrs.rs
Normal file
163
crates/neuron/src/harness/mistralrs.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
//! mistral.rs harness implementation.
|
||||
//!
|
||||
//! Wraps the mistral.rs HTTP API for model lifecycle management
|
||||
//! and optionally manages the process via systemd.
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use cortex_core::harness::{Harness, HarnessConfig, HarnessHealth, ModelInfo, ModelSpec};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
pub struct MistralRsHarness {
|
||||
endpoint: String,
|
||||
systemd_unit: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl MistralRsHarness {
|
||||
pub fn new(endpoint: String, systemd_unit: Option<String>) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
systemd_unit,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("failed to build HTTP client"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response from mistral.rs `GET /v1/models`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<ModelEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelEntry {
|
||||
id: String,
|
||||
#[serde(default)]
|
||||
status: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Harness for MistralRsHarness {
|
||||
fn name(&self) -> &str {
|
||||
"mistralrs"
|
||||
}
|
||||
|
||||
async fn start(&self, _config: &HarnessConfig) -> Result<()> {
|
||||
let Some(unit) = &self.systemd_unit else {
|
||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("systemctl")
|
||||
.args(["start", unit])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("systemctl start {unit} failed: {stderr}");
|
||||
}
|
||||
|
||||
// Wait for the health endpoint to respond (up to 30s).
|
||||
let url = format!("{}/health", self.endpoint);
|
||||
for _ in 0..30 {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
if self.client.get(&url).send().await.is_ok() {
|
||||
tracing::info!(unit, "mistralrs started and healthy");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
anyhow::bail!("mistralrs started but health endpoint did not respond within 30s");
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<()> {
|
||||
let Some(unit) = &self.systemd_unit else {
|
||||
anyhow::bail!("no systemd unit configured for mistralrs harness");
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("systemctl")
|
||||
.args(["stop", unit])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("systemctl stop {unit} failed: {stderr}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health(&self) -> HarnessHealth {
|
||||
let url = format!("{}/health", self.endpoint);
|
||||
let running = self.client.get(&url).send().await.is_ok();
|
||||
HarnessHealth {
|
||||
name: "mistralrs".into(),
|
||||
running,
|
||||
uptime_secs: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
let url = format!("{}/v1/models", self.endpoint);
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
anyhow::bail!("GET /v1/models returned {}", resp.status());
|
||||
}
|
||||
|
||||
let models_resp: ModelsResponse = resp.json().await?;
|
||||
Ok(models_resp
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|m| ModelInfo {
|
||||
id: m.id,
|
||||
harness: "mistralrs".into(),
|
||||
status: m.status.unwrap_or_else(|| "loaded".into()),
|
||||
devices: vec![],
|
||||
vram_used_mb: None,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn load_model(&self, spec: &ModelSpec) -> Result<()> {
|
||||
let url = format!("{}/v1/models/reload", self.endpoint);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({ "model_id": spec.model_id }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("POST /v1/models/reload failed: {body}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||
let url = format!("{}/v1/models/unload", self.endpoint);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({ "model_id": model_id }))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("POST /v1/models/unload failed: {body}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn inference_endpoint(&self, _model_id: &str) -> Option<String> {
|
||||
// mistral.rs routes internally by model name in the request body,
|
||||
// so the inference endpoint is always the base URL.
|
||||
Some(self.endpoint.clone())
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user