Compare commits
38 Commits
phase-2-pr
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
60f5598542
|
|||
|
7945240646
|
|||
|
0c74d89d15
|
|||
|
c94a2ae755
|
|||
|
99920dd322
|
|||
|
c4f239ceb9
|
|||
|
ac445c1569
|
|||
|
abc6e605b8
|
|||
|
4f2957af9e
|
|||
|
75cd088b61
|
|||
|
d311c8ca7a
|
|||
|
c97a8654f5
|
|||
|
dc048ffcc9
|
|||
|
7ebcfba5ca
|
|||
|
825bf4e905
|
|||
|
4c12c7e2f0
|
|||
|
ba1b5ba408
|
|||
|
5731f4c318
|
|||
|
fa013505d1
|
|||
|
c8bcaabc38
|
|||
|
7ad56c6a86
|
|||
|
1b0e36c119
|
|||
|
ed2d09864e
|
|||
|
4994b94c84
|
|||
|
9a24b05866
|
|||
|
7bb033b4ed
|
|||
|
f8c0da0ebf
|
|||
|
dd592d918d
|
|||
|
766c20ba47
|
|||
|
4972c7d1e7
|
|||
|
a26bb9f04b
|
|||
|
ea1fdf8aa6
|
|||
|
577781de8d
|
|||
|
24968e9233
|
|||
|
7df84fed8f
|
|||
|
5c520c7e90
|
|||
|
d0292ed377
|
|||
|
d4e1b05956
|
146
.gitea/workflows/deploy.yml
Normal file
146
.gitea/workflows/deploy.yml
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
name: deploy
|
||||||
|
|
||||||
|
# Roll the freshly-published unstable RPMs onto the helexa fleet:
|
||||||
|
# cortex on the gateway, helexa-neuron-<flavour> on each neuron host.
|
||||||
|
#
|
||||||
|
# Triggered automatically after `build-prerelease` succeeds (by which
|
||||||
|
# point the new RPMs are live on rpm.lair.cafe/unstable), and also
|
||||||
|
# re-runnable manually from the Gitea UI.
|
||||||
|
#
|
||||||
|
# Per-host one-time setup (gitea_ci user, authorized_keys, scoped
|
||||||
|
# sudoers drop-in) lives in script/infra-setup.sh — run that once per
|
||||||
|
# host before this workflow can succeed.
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_run:
|
||||||
|
workflows: [build-prerelease]
|
||||||
|
types: [completed]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
# Serialize deploys. Overlapping runs would race on dnf metadata
|
||||||
|
# refresh and service-restart timing; queueing keeps the fleet
|
||||||
|
# predictable. Don't cancel an in-flight deploy — a half-applied dnf
|
||||||
|
# transaction is worse than a slightly stale deploy.
|
||||||
|
concurrency:
|
||||||
|
group: deploy
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
env:
|
||||||
|
DEPLOY_KEY: |
|
||||||
|
${{ secrets.RSYNC_SSH_KEY }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy-cortex:
|
||||||
|
runs-on: fedora-43
|
||||||
|
# Two trigger paths: manual dispatch always runs; workflow_run
|
||||||
|
# only runs if the upstream `build-prerelease` actually succeeded.
|
||||||
|
if: >-
|
||||||
|
${{
|
||||||
|
github.event_name == 'workflow_dispatch'
|
||||||
|
|| github.event.workflow_run.conclusion == 'success'
|
||||||
|
}}
|
||||||
|
steps:
|
||||||
|
- name: SSH init
|
||||||
|
run: |
|
||||||
|
mkdir -p ~/.ssh
|
||||||
|
echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
|
||||||
|
chmod 600 ~/.ssh/id_ed25519
|
||||||
|
ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
|
||||||
|
gitea_ci@hanzalova.internal 'hostname -f'
|
||||||
|
|
||||||
|
- name: Stop cortex.service
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@hanzalova.internal '
|
||||||
|
if systemctl is-active --quiet cortex.service; then
|
||||||
|
sudo /usr/bin/systemctl stop cortex.service
|
||||||
|
fi'
|
||||||
|
|
||||||
|
- name: Install / upgrade cortex from rpm.lair.cafe/unstable
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@hanzalova.internal '
|
||||||
|
if rpm -q cortex >/dev/null 2>&1; then
|
||||||
|
sudo /usr/bin/dnf upgrade --refresh --allowerasing -y cortex
|
||||||
|
else
|
||||||
|
sudo /usr/bin/dnf install --refresh --allowerasing -y cortex
|
||||||
|
fi'
|
||||||
|
|
||||||
|
- name: Start cortex.service
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@hanzalova.internal '
|
||||||
|
sudo /usr/bin/systemctl daemon-reload
|
||||||
|
sudo /usr/bin/systemctl start cortex.service'
|
||||||
|
|
||||||
|
# Wait for the service to either come up or wedge, then capture
|
||||||
|
# the latest-invocation journal. Runs even on prior failure so a
|
||||||
|
# failed start step still leaves a usable record in the deploy log.
|
||||||
|
- name: Capture cortex.service startup journal
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
sleep 10
|
||||||
|
ssh gitea_ci@hanzalova.internal \
|
||||||
|
'journalctl --unit cortex.service -I --no-pager'
|
||||||
|
|
||||||
|
deploy-neurons:
|
||||||
|
needs: [deploy-cortex]
|
||||||
|
runs-on: fedora-43
|
||||||
|
strategy:
|
||||||
|
# One neuron failing must not cancel the others. Cortex is up
|
||||||
|
# already; a partial neuron deploy is strictly better than
|
||||||
|
# rolling back to zero.
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- host: beast.hanzalova.internal
|
||||||
|
flavour: blackwell
|
||||||
|
- host: benjy.hanzalova.internal
|
||||||
|
flavour: ada
|
||||||
|
- host: quadbrat.hanzalova.internal
|
||||||
|
flavour: ampere
|
||||||
|
steps:
|
||||||
|
- name: SSH init
|
||||||
|
run: |
|
||||||
|
mkdir -p ~/.ssh
|
||||||
|
echo "${DEPLOY_KEY}" > ~/.ssh/id_ed25519
|
||||||
|
chmod 600 ~/.ssh/id_ed25519
|
||||||
|
ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=accept-new \
|
||||||
|
gitea_ci@${{ matrix.host }} 'hostname -f'
|
||||||
|
|
||||||
|
- name: Stop neuron.service
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@${{ matrix.host }} '
|
||||||
|
if systemctl is-active --quiet neuron.service; then
|
||||||
|
sudo /usr/bin/systemctl stop neuron.service
|
||||||
|
fi'
|
||||||
|
|
||||||
|
- name: Install / upgrade helexa-neuron-${{ matrix.flavour }}
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@${{ matrix.host }} "
|
||||||
|
if rpm -q helexa-neuron-${{ matrix.flavour }} >/dev/null 2>&1; then
|
||||||
|
sudo /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-${{ matrix.flavour }}
|
||||||
|
else
|
||||||
|
sudo /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-${{ matrix.flavour }}
|
||||||
|
fi"
|
||||||
|
|
||||||
|
- name: Ensure firewalld allows helexa-neuron
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@${{ matrix.host }} '
|
||||||
|
if ! sudo /usr/bin/firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null; then
|
||||||
|
sudo /usr/bin/firewall-cmd --add-service=helexa-neuron --permanent
|
||||||
|
sudo /usr/bin/firewall-cmd --reload
|
||||||
|
fi'
|
||||||
|
|
||||||
|
- name: Start neuron.service
|
||||||
|
run: |
|
||||||
|
ssh gitea_ci@${{ matrix.host }} '
|
||||||
|
sudo /usr/bin/systemctl daemon-reload
|
||||||
|
sudo /usr/bin/systemctl start neuron.service'
|
||||||
|
|
||||||
|
# Wait for the service to either come up or wedge, then capture
|
||||||
|
# the latest-invocation journal. Runs even on prior failure so a
|
||||||
|
# failed start step still leaves a usable record in the deploy log.
|
||||||
|
- name: Capture neuron.service startup journal
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
sleep 10
|
||||||
|
ssh gitea_ci@${{ matrix.host }} \
|
||||||
|
'journalctl --unit neuron.service -I --no-pager'
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@ cortex.toml
|
|||||||
models.toml
|
models.toml
|
||||||
doc/plan/*
|
doc/plan/*
|
||||||
/target-cuda/
|
/target-cuda/
|
||||||
|
.claude/
|
||||||
|
|||||||
131
Cargo.lock
generated
131
Cargo.lock
generated
@@ -472,6 +472,12 @@ version = "1.5.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "byteorder-lite"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bytes"
|
name = "bytes"
|
||||||
version = "1.11.1"
|
version = "1.11.1"
|
||||||
@@ -668,6 +674,12 @@ dependencies = [
|
|||||||
"cc",
|
"cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "color_quant"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorchoice"
|
name = "colorchoice"
|
||||||
version = "1.0.5"
|
version = "1.0.5"
|
||||||
@@ -893,8 +905,7 @@ dependencies = [
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "cudarc"
|
name = "cudarc"
|
||||||
version = "0.19.7"
|
version = "0.19.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/grenade/cudarc?rev=63327a256059f8252641ae46c6bb9eefe707f382#63327a256059f8252641ae46c6bb9eefe707f382"
|
||||||
checksum = "1cea5f10a99e025c1b44ae2354c2d8326b25ddbd0baf76bde8e55cfd4018a2cc"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"float8",
|
"float8",
|
||||||
"half",
|
"half",
|
||||||
@@ -1223,6 +1234,15 @@ version = "2.4.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
|
checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fdeflate"
|
||||||
|
version = "0.3.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
|
||||||
|
dependencies = [
|
||||||
|
"simd-adler32",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "figment"
|
name = "figment"
|
||||||
version = "0.10.19"
|
version = "0.10.19"
|
||||||
@@ -1731,6 +1751,16 @@ dependencies = [
|
|||||||
"wasip3",
|
"wasip3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gif"
|
||||||
|
version = "0.14.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159"
|
||||||
|
dependencies = [
|
||||||
|
"color_quant",
|
||||||
|
"weezl",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "glob"
|
name = "glob"
|
||||||
version = "0.3.3"
|
version = "0.3.3"
|
||||||
@@ -2135,6 +2165,34 @@ dependencies = [
|
|||||||
"icu_properties",
|
"icu_properties",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "image"
|
||||||
|
version = "0.25.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"byteorder-lite",
|
||||||
|
"color_quant",
|
||||||
|
"gif",
|
||||||
|
"image-webp",
|
||||||
|
"moxcms",
|
||||||
|
"num-traits",
|
||||||
|
"png",
|
||||||
|
"zune-core",
|
||||||
|
"zune-jpeg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "image-webp"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder-lite",
|
||||||
|
"quick-error",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "1.9.3"
|
version = "1.9.3"
|
||||||
@@ -2449,6 +2507,16 @@ dependencies = [
|
|||||||
"serde_json",
|
"serde_json",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minijinja-contrib"
|
||||||
|
version = "2.20.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "99df5123c54391e2a228014c1dbbd85a3dab08a25e776c810526f2f47542b3de"
|
||||||
|
dependencies = [
|
||||||
|
"minijinja",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minimal-lexical"
|
name = "minimal-lexical"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@@ -2498,6 +2566,16 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "moxcms"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
"pxfm",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "native-tls"
|
name = "native-tls"
|
||||||
version = "0.2.18"
|
version = "0.2.18"
|
||||||
@@ -2522,6 +2600,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
|
"base64 0.22.1",
|
||||||
"candle-core",
|
"candle-core",
|
||||||
"candle-nn",
|
"candle-nn",
|
||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
@@ -2533,7 +2612,9 @@ dependencies = [
|
|||||||
"futures",
|
"futures",
|
||||||
"half",
|
"half",
|
||||||
"hf-hub",
|
"hf-hub",
|
||||||
|
"image",
|
||||||
"minijinja",
|
"minijinja",
|
||||||
|
"minijinja-contrib",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"safetensors 0.7.0",
|
"safetensors 0.7.0",
|
||||||
"serde",
|
"serde",
|
||||||
@@ -2861,6 +2942,19 @@ version = "0.3.33"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e"
|
checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "png"
|
||||||
|
version = "0.18.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"crc32fast",
|
||||||
|
"fdeflate",
|
||||||
|
"flate2",
|
||||||
|
"miniz_oxide",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polling"
|
name = "polling"
|
||||||
version = "3.11.0"
|
version = "3.11.0"
|
||||||
@@ -2974,6 +3068,12 @@ version = "0.1.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
|
checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pxfm"
|
||||||
|
version = "0.1.29"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quanta"
|
name = "quanta"
|
||||||
version = "0.12.6"
|
version = "0.12.6"
|
||||||
@@ -2989,6 +3089,12 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quick-error"
|
||||||
|
version = "2.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quinn"
|
name = "quinn"
|
||||||
version = "0.11.9"
|
version = "0.11.9"
|
||||||
@@ -4627,6 +4733,12 @@ dependencies = [
|
|||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "weezl"
|
||||||
|
version = "0.1.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "which"
|
name = "which"
|
||||||
version = "7.0.3"
|
version = "7.0.3"
|
||||||
@@ -5164,3 +5276,18 @@ name = "zmij"
|
|||||||
version = "1.0.21"
|
version = "1.0.21"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
|
checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zune-core"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zune-jpeg"
|
||||||
|
version = "0.5.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296"
|
||||||
|
dependencies = [
|
||||||
|
"zune-core",
|
||||||
|
]
|
||||||
|
|||||||
@@ -61,3 +61,12 @@ eventsource-stream = "0.2"
|
|||||||
# workspace crates
|
# workspace crates
|
||||||
cortex-core = { path = "crates/cortex-core" }
|
cortex-core = { path = "crates/cortex-core" }
|
||||||
cortex-gateway = { path = "crates/cortex-gateway" }
|
cortex-gateway = { path = "crates/cortex-gateway" }
|
||||||
|
|
||||||
|
# Patched cudarc (affects neuron's 0.19.x only; candle's 0.17.x is
|
||||||
|
# untouched since the fork is 0.19.7 and doesn't satisfy a 0.17 req). Adds
|
||||||
|
# Comm::abort / get_async_error / raw comm() — needed for #17 Stage 2 TP
|
||||||
|
# hang-recovery (abort a wedged collective from another thread, then
|
||||||
|
# rebuild the comm). Pinned to a fork revision pending upstream review
|
||||||
|
# (grenade/cudarc @ nccl-comm-abort).
|
||||||
|
[patch.crates-io]
|
||||||
|
cudarc = { git = "https://github.com/grenade/cudarc", rev = "63327a256059f8252641ae46c6bb9eefe707f382" }
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
# Helexa fleet manifest.
|
|
||||||
#
|
|
||||||
# Drives rolling deploys via script/deploy.sh and serves as the source
|
|
||||||
# of truth for which hosts run cortex vs neuron, and which CUDA
|
|
||||||
# compute-capability flavour each neuron host needs.
|
|
||||||
#
|
|
||||||
# Flavour ↔ NVIDIA generation ↔ compute cap:
|
|
||||||
# ampere sm_86 (RTX 30 series — e.g. 3060)
|
|
||||||
# ada sm_89 (RTX 40 series — e.g. 4090)
|
|
||||||
# blackwell sm_120 (RTX 50 series — e.g. 5090)
|
|
||||||
#
|
|
||||||
# The flavour determines which RPM is installed on a given neuron host:
|
|
||||||
# helexa-neuron-<flavour>. Only one flavour may be installed at a time
|
|
||||||
# (the packages Conflict: with each other).
|
|
||||||
|
|
||||||
cortex:
|
|
||||||
host: hanzalova.internal
|
|
||||||
|
|
||||||
neurons:
|
|
||||||
- host: beast.hanzalova.internal
|
|
||||||
flavour: blackwell
|
|
||||||
gpu: "2x RTX 5090"
|
|
||||||
|
|
||||||
- host: benjy.hanzalova.internal
|
|
||||||
flavour: ada
|
|
||||||
gpu: "RTX 4090"
|
|
||||||
|
|
||||||
- host: quadbrat.hanzalova.internal
|
|
||||||
flavour: ampere
|
|
||||||
gpu: "RTX 3060"
|
|
||||||
@@ -5,9 +5,9 @@
|
|||||||
# invocation: `validate-neuron.sh beast.hanzalova.internal
|
# invocation: `validate-neuron.sh beast.hanzalova.internal
|
||||||
# Qwen/Qwen3.6-27B q5k 2`.
|
# Qwen/Qwen3.6-27B q5k 2`.
|
||||||
#
|
#
|
||||||
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml. Edits
|
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh. Edits
|
||||||
# take effect on the next deploy.sh run (which stops + restarts the
|
# take effect after the next deploy workflow run restarts the service
|
||||||
# service so default_models is re-read at activation).
|
# (default_models is read at activation).
|
||||||
|
|
||||||
port = 13131
|
port = 13131
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
|
# Qwen3-8B (bf16, ~18 GB), leaving ~6 GB for KV cache + activations on
|
||||||
# moderate-length contexts.
|
# moderate-length contexts.
|
||||||
#
|
#
|
||||||
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
|
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh.
|
||||||
|
|
||||||
port = 13131
|
port = 13131
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
|
# (bf16, ~4 GB), leaving ~7 GB for KV cache so long contexts on a small
|
||||||
# model still have plenty of room.
|
# model still have plenty of room.
|
||||||
#
|
#
|
||||||
# Synced by script/deploy.sh from asset/neuron/<short-host>.toml.
|
# Synced to /etc/neuron/neuron.toml by script/infra-setup.sh.
|
||||||
|
|
||||||
port = 13131
|
port = 13131
|
||||||
|
|
||||||
|
|||||||
20
asset/sudoers.d/cortex-host.conf
Normal file
20
asset/sudoers.d/cortex-host.conf
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Install on the cortex gateway host as /etc/sudoers.d/helexa_gitea_ci
|
||||||
|
# (owner root:root, mode 0440). Required by .gitea/workflows/deploy.yml,
|
||||||
|
# which SSHes as gitea_ci@<gateway> to roll out cortex package upgrades
|
||||||
|
# and config changes.
|
||||||
|
#
|
||||||
|
# Filename convention `helexa_gitea_ci` (vs bare `gitea_ci`) so other
|
||||||
|
# helexa-org apps can drop their own sudoers files on the same host
|
||||||
|
# without overwriting this one.
|
||||||
|
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/cortex/cortex.toml
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/cortex/models.toml
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl start cortex.service
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl stop cortex.service
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl daemon-reload
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y cortex
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y cortex
|
||||||
|
# sudoers reserves `:` and `=` and requires `\` escaping inside command
|
||||||
|
# arguments — without it visudo errors at the first `:` in `https://`.
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://rpm.lair.cafe/lair-cafe-unstable.repo
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager setopt lair-cafe-unstable.enabled\=1
|
||||||
33
asset/sudoers.d/neuron-host.conf
Normal file
33
asset/sudoers.d/neuron-host.conf
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Install on every neuron host as /etc/sudoers.d/helexa_gitea_ci
|
||||||
|
# (owner root:root, mode 0440). Required by .gitea/workflows/deploy.yml,
|
||||||
|
# which SSHes as gitea_ci@<neuron-host> to roll out helexa-neuron-<flavour>
|
||||||
|
# package upgrades and config changes.
|
||||||
|
#
|
||||||
|
# Filename convention `helexa_gitea_ci` (vs bare `gitea_ci`) so other
|
||||||
|
# helexa-org apps can drop their own sudoers files on the same host
|
||||||
|
# without overwriting this one.
|
||||||
|
#
|
||||||
|
# All three CUDA flavours are listed because a host's flavour can change
|
||||||
|
# (e.g. GPU swap) and we don't want the sudoers file to need to change
|
||||||
|
# in lockstep. Only one flavour can be installed at a time (the packages
|
||||||
|
# Conflict: with each other), so the attack surface is bounded to "wrong
|
||||||
|
# flavour installed" — vandalism, not privilege escalation.
|
||||||
|
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/rsync * /etc/neuron/neuron.toml
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl start neuron.service
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl stop neuron.service
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/systemctl daemon-reload
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-ampere
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-ampere
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-ada
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-ada
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install --refresh --allowerasing -y helexa-neuron-blackwell
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf upgrade --refresh --allowerasing -y helexa-neuron-blackwell
|
||||||
|
# sudoers reserves `:` and `=` and requires `\` escaping inside command
|
||||||
|
# arguments — without it visudo errors at the first `:` in `https://`.
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://rpm.lair.cafe/lair-cafe-unstable.repo
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager setopt lair-cafe-unstable.enabled\=1
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf config-manager addrepo --from-repofile\=https\://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/dnf install -y libcudnn9-cuda-13
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/firewall-cmd --add-service=helexa-neuron --permanent
|
||||||
|
gitea_ci ALL=(root) NOPASSWD: /usr/bin/firewall-cmd --reload
|
||||||
@@ -24,6 +24,17 @@ pub struct ModelProfile {
|
|||||||
/// Neurons where this model should never be evicted.
|
/// Neurons where this model should never be evicted.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub pinned_on: Vec<String>,
|
pub pinned_on: Vec<String>,
|
||||||
|
/// Source scheme this profile's weights come from. When set, the
|
||||||
|
/// router prefixes `id` with `scheme:` before forwarding the load
|
||||||
|
/// request to neuron, ensuring the daemon fetches from the right
|
||||||
|
/// registry regardless of which entry happens to match `id`.
|
||||||
|
///
|
||||||
|
/// `None` lets neuron substitute its own `default_source` (typically
|
||||||
|
/// `huggingface`). Set to `"helexa"` when the model is hosted in
|
||||||
|
/// the helexa registry — operator-procurement-grade audit relies
|
||||||
|
/// on this being explicit per model rather than implicit.
|
||||||
|
#[serde(default)]
|
||||||
|
pub source: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_min_devices() -> u32 {
|
fn default_min_devices() -> u32 {
|
||||||
@@ -140,6 +151,7 @@ mod tests {
|
|||||||
min_devices: 2,
|
min_devices: 2,
|
||||||
min_device_vram_mb: Some(24_000),
|
min_device_vram_mb: Some(24_000),
|
||||||
pinned_on: vec![],
|
pinned_on: vec![],
|
||||||
|
source: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,6 +209,29 @@ mod tests {
|
|||||||
assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
|
assert_eq!(cat.resolve_alias("Qwen/Qwen3-8B"), "Qwen/Qwen3-8B");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn source_defaults_to_none_when_absent_from_toml() {
|
||||||
|
let src = r#"
|
||||||
|
[[models]]
|
||||||
|
id = "Qwen/Qwen3-30B"
|
||||||
|
harness = "candle"
|
||||||
|
"#;
|
||||||
|
let cat: ModelCatalogue = toml::from_str(src).expect("parse models table");
|
||||||
|
assert!(cat.models[0].source.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn source_round_trips_through_toml() {
|
||||||
|
let src = r#"
|
||||||
|
[[models]]
|
||||||
|
id = "Helexa/Qwen3.6-27B-Uncensored"
|
||||||
|
harness = "candle"
|
||||||
|
source = "helexa"
|
||||||
|
"#;
|
||||||
|
let cat: ModelCatalogue = toml::from_str(src).expect("parse models table");
|
||||||
|
assert_eq!(cat.models[0].source.as_deref(), Some("helexa"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn aliases_table_round_trips_through_toml() {
|
fn aliases_table_round_trips_through_toml() {
|
||||||
let src = r#"
|
let src = r#"
|
||||||
|
|||||||
@@ -44,6 +44,16 @@ pub struct ModelInfo {
|
|||||||
pub status: String,
|
pub status: String,
|
||||||
pub devices: Vec<u32>,
|
pub devices: Vec<u32>,
|
||||||
pub vram_used_mb: Option<u64>,
|
pub vram_used_mb: Option<u64>,
|
||||||
|
/// Modalities this loaded model supports. Today: `["text"]` for
|
||||||
|
/// text-only checkpoints, `["text", "vision"]` for vision-capable
|
||||||
|
/// ones (Stage B7 of the vision plan). Clients like litellm /
|
||||||
|
/// agent0 can gate `image_url` submission on the advertised set.
|
||||||
|
///
|
||||||
|
/// Optional in the wire format so older clients that don't read
|
||||||
|
/// it stay compatible. Default-empty for absent/older data, which
|
||||||
|
/// callers can interpret as "text".
|
||||||
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub capabilities: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// What an inference harness must do, from neuron's perspective.
|
/// What an inference harness must do, from neuron's perspective.
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ pub mod metrics;
|
|||||||
pub mod node;
|
pub mod node;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod responses;
|
pub mod responses;
|
||||||
|
pub mod source;
|
||||||
pub mod translate;
|
pub mod translate;
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ pub struct ModelEntry {
|
|||||||
pub last_accessed: Option<DateTime<Utc>>,
|
pub last_accessed: Option<DateTime<Utc>>,
|
||||||
/// Estimated VRAM usage in MB when loaded.
|
/// Estimated VRAM usage in MB when loaded.
|
||||||
pub vram_estimate_mb: Option<u64>,
|
pub vram_estimate_mb: Option<u64>,
|
||||||
|
/// Modalities the loaded model advertises (e.g. `["text", "vision"]`),
|
||||||
|
/// copied verbatim from the neuron's `ModelInfo.capabilities` at poll
|
||||||
|
/// time. Empty when the neuron reports none. `#[serde(default)]` keeps
|
||||||
|
/// older persisted/serialised entries deserialisable.
|
||||||
|
#[serde(default)]
|
||||||
|
pub capabilities: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Model lifecycle status.
|
/// Model lifecycle status.
|
||||||
@@ -85,6 +91,12 @@ pub struct CortexModelEntry {
|
|||||||
/// disjoint from) `feasible_on` depending on whether the catalogue
|
/// disjoint from) `feasible_on` depending on whether the catalogue
|
||||||
/// covers this model.
|
/// covers this model.
|
||||||
pub locations: Vec<ModelLocation>,
|
pub locations: Vec<ModelLocation>,
|
||||||
|
/// Union of the modalities advertised by every neuron that has this
|
||||||
|
/// model loaded (e.g. `["text", "vision"]`). Empty for catalogue-only
|
||||||
|
/// entries with no loaded location — the catalogue profile doesn't
|
||||||
|
/// declare capabilities yet (tracked separately from C3).
|
||||||
|
#[serde(default)]
|
||||||
|
pub capabilities: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|||||||
267
crates/cortex-core/src/source.rs
Normal file
267
crates/cortex-core/src/source.rs
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
//! Scheme-qualified model identifiers.
|
||||||
|
//!
|
||||||
|
//! cortex/neuron historically resolves every model id through hf-hub
|
||||||
|
//! against `https://huggingface.co`. Helexa is adding an EU-hosted
|
||||||
|
//! registry (`registry.helexa.ai`) alongside HF — both speak the same
|
||||||
|
//! HF-compatible wire format, but the bytes, jurisdiction, and trust
|
||||||
|
//! root differ. Model ids therefore need a scheme:
|
||||||
|
//!
|
||||||
|
//! - `huggingface:Qwen/Qwen3.6-27B` — HF-hosted bytes
|
||||||
|
//! - `helexa:Qwen/Qwen3.6-27B-Uncensored` — helexa registry bytes
|
||||||
|
//! - `helexa:SomeOperator/CustomFinetune` — operator publishing
|
||||||
|
//! under the helexa namespace; same scheme handles all `org/name`
|
||||||
|
//! pairs hosted in that registry.
|
||||||
|
//!
|
||||||
|
//! Bare `org/name` parses with an empty scheme; the caller (typically
|
||||||
|
//! a harness) substitutes its configured default scheme so existing
|
||||||
|
//! configs keep working through the transition.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
/// Parsed `scheme:org/name`. Bare `org/name` produces an empty scheme
|
||||||
|
/// — call `with_default_scheme` (or check `is_scheme_unset`) to
|
||||||
|
/// resolve before using.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||||
|
pub struct ModelSourceId {
|
||||||
|
pub scheme: String,
|
||||||
|
pub org: String,
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors from `ModelSourceId::from_str`. Carries the offending input
|
||||||
|
/// so log lines / API errors can echo what the operator typed.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
|
||||||
|
pub enum ParseError {
|
||||||
|
#[error("empty model id")]
|
||||||
|
Empty,
|
||||||
|
#[error("model id '{0}' is missing the '/' between org and name")]
|
||||||
|
MissingSlash(String),
|
||||||
|
#[error("model id '{0}' has an empty scheme before ':'")]
|
||||||
|
EmptyScheme(String),
|
||||||
|
#[error("model id '{0}' has an empty org")]
|
||||||
|
EmptyOrg(String),
|
||||||
|
#[error("model id '{0}' has an empty name")]
|
||||||
|
EmptyName(String),
|
||||||
|
#[error("model id '{0}' has a scheme containing '/' which is reserved for org/name")]
|
||||||
|
SchemeContainsSlash(String),
|
||||||
|
#[error("model id '{0}' has a name containing ':' which is reserved for the scheme prefix")]
|
||||||
|
NameContainsColon(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelSourceId {
|
||||||
|
/// Construct directly from already-validated parts. Used by tests
|
||||||
|
/// and call sites that have the fields separately; the public API
|
||||||
|
/// for parsing user input is `FromStr`.
|
||||||
|
pub fn new(scheme: impl Into<String>, org: impl Into<String>, name: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
scheme: scheme.into(),
|
||||||
|
org: org.into(),
|
||||||
|
name: name.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// True when this id parsed from a bare `org/name` (no scheme
|
||||||
|
/// prefix). The harness substitutes its configured default in
|
||||||
|
/// `with_default_scheme` before resolving against a registry.
|
||||||
|
pub fn is_scheme_unset(&self) -> bool {
|
||||||
|
self.scheme.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Substitute `default` for an empty scheme. No-op when the scheme
|
||||||
|
/// is already set. Returns self by value so it composes neatly:
|
||||||
|
/// `id.parse::<ModelSourceId>()?.with_default_scheme("huggingface")`.
|
||||||
|
pub fn with_default_scheme(mut self, default: &str) -> Self {
|
||||||
|
if self.scheme.is_empty() {
|
||||||
|
self.scheme = default.to_string();
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The `org/name` half — what an hf-hub `Api::model(...)` call
|
||||||
|
/// expects regardless of which scheme/endpoint we're hitting.
|
||||||
|
pub fn repo_path(&self) -> String {
|
||||||
|
format!("{}/{}", self.org, self.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for ModelSourceId {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
if self.scheme.is_empty() {
|
||||||
|
write!(f, "{}/{}", self.org, self.name)
|
||||||
|
} else {
|
||||||
|
write!(f, "{}:{}/{}", self.scheme, self.org, self.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for ModelSourceId {
|
||||||
|
type Err = ParseError;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
if s.is_empty() {
|
||||||
|
return Err(ParseError::Empty);
|
||||||
|
}
|
||||||
|
// Scheme split. Only the *first* colon counts — anything after
|
||||||
|
// belongs to org/name (and would be rejected separately because
|
||||||
|
// `:` isn't allowed there).
|
||||||
|
let (scheme, rest) = match s.split_once(':') {
|
||||||
|
Some((scheme, rest)) => {
|
||||||
|
if scheme.is_empty() {
|
||||||
|
return Err(ParseError::EmptyScheme(s.to_string()));
|
||||||
|
}
|
||||||
|
if scheme.contains('/') {
|
||||||
|
return Err(ParseError::SchemeContainsSlash(s.to_string()));
|
||||||
|
}
|
||||||
|
(scheme.to_string(), rest)
|
||||||
|
}
|
||||||
|
None => (String::new(), s),
|
||||||
|
};
|
||||||
|
let (org, name) = rest
|
||||||
|
.split_once('/')
|
||||||
|
.ok_or_else(|| ParseError::MissingSlash(s.to_string()))?;
|
||||||
|
if org.is_empty() {
|
||||||
|
return Err(ParseError::EmptyOrg(s.to_string()));
|
||||||
|
}
|
||||||
|
if name.is_empty() {
|
||||||
|
return Err(ParseError::EmptyName(s.to_string()));
|
||||||
|
}
|
||||||
|
if name.contains(':') {
|
||||||
|
return Err(ParseError::NameContainsColon(s.to_string()));
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
scheme,
|
||||||
|
org: org.to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, panicking if the parser rejects it.
    fn accepted(input: &str) -> ModelSourceId {
        input.parse().unwrap()
    }

    /// Parse `input`, panicking if the parser accepts it.
    fn rejected(input: &str) -> ParseError {
        input.parse::<ModelSourceId>().unwrap_err()
    }

    #[test]
    fn parses_qualified() {
        let id = accepted("huggingface:Qwen/Qwen3.6-27B");
        assert_eq!(
            (id.scheme.as_str(), id.org.as_str(), id.name.as_str()),
            ("huggingface", "Qwen", "Qwen3.6-27B")
        );
        assert_eq!(id.repo_path(), "Qwen/Qwen3.6-27B");
        assert!(!id.is_scheme_unset());
    }

    #[test]
    fn parses_helexa_scheme() {
        let id = accepted("helexa:SomeOperator/Qwen3.6-27B-Uncensored");
        assert_eq!(
            (id.scheme.as_str(), id.org.as_str(), id.name.as_str()),
            ("helexa", "SomeOperator", "Qwen3.6-27B-Uncensored")
        );
    }

    #[test]
    fn parses_bare_id_with_empty_scheme() {
        let id = accepted("Qwen/Qwen3-30B-A3B-Instruct");
        assert_eq!(
            (id.scheme.as_str(), id.org.as_str(), id.name.as_str()),
            ("", "Qwen", "Qwen3-30B-A3B-Instruct")
        );
        assert!(id.is_scheme_unset());
    }

    #[test]
    fn substitutes_default_scheme_only_when_unset() {
        // A bare id picks up the default...
        let bare = accepted("Qwen/Q3").with_default_scheme("huggingface");
        assert_eq!(bare.scheme, "huggingface");

        // ...while an explicit scheme survives untouched.
        let explicit = accepted("helexa:Qwen/Q3").with_default_scheme("huggingface");
        assert_eq!(
            explicit.scheme, "helexa",
            "default substitution must not override an explicit scheme"
        );
    }

    #[test]
    fn display_roundtrips_qualified_id() {
        let text = "helexa:Helexa/Qwen3.6-27B";
        assert_eq!(accepted(text).to_string(), text);
    }

    #[test]
    fn display_roundtrips_bare_id() {
        let text = "Qwen/Q3";
        assert_eq!(accepted(text).to_string(), text);
    }

    #[test]
    fn rejects_empty() {
        assert_eq!(rejected(""), ParseError::Empty);
    }

    #[test]
    fn rejects_missing_slash() {
        // Both the bare and the scheme-qualified form must echo the input.
        for input in ["Qwen", "huggingface:Qwen"] {
            assert_eq!(rejected(input), ParseError::MissingSlash(input.to_string()));
        }
    }

    #[test]
    fn rejects_empty_scheme() {
        assert_eq!(
            rejected(":Qwen/Q3"),
            ParseError::EmptyScheme(":Qwen/Q3".to_string())
        );
    }

    #[test]
    fn rejects_scheme_with_slash() {
        assert_eq!(
            rejected("hugg/ingface:Q/N"),
            ParseError::SchemeContainsSlash("hugg/ingface:Q/N".to_string())
        );
    }

    #[test]
    fn rejects_empty_org_or_name() {
        // Only the variant is pinned here, not the echoed payload.
        let missing_org = rejected("huggingface:/N");
        assert!(
            matches!(missing_org, ParseError::EmptyOrg(_)),
            "expected EmptyOrg, got {missing_org:?}"
        );

        let missing_name = rejected("huggingface:Q/");
        assert!(
            matches!(missing_name, ParseError::EmptyName(_)),
            "expected EmptyName, got {missing_name:?}"
        );
    }

    #[test]
    fn rejects_name_with_colon() {
        assert_eq!(
            rejected("huggingface:Q/N:weird"),
            ParseError::NameContainsColon("huggingface:Q/N:weird".to_string())
        );
    }

    #[test]
    fn serde_roundtrips_via_struct() {
        // The serialized form is the struct itself (scheme/org/name
        // fields), which keeps API payloads self-describing. The compact
        // `scheme:org/name` string is the job of `Display`/`FromStr`.
        let original = ModelSourceId::new("helexa", "Helexa", "Qwen3.6-27B");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ModelSourceId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
}
|
||||||
@@ -414,6 +414,9 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
loaded: false,
|
loaded: false,
|
||||||
feasible_on,
|
feasible_on,
|
||||||
locations: Vec::new(),
|
locations: Vec::new(),
|
||||||
|
// Catalogue profiles don't declare capabilities yet;
|
||||||
|
// the union is filled in Pass 2 from loaded locations.
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -438,6 +441,14 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
if was_loaded {
|
if was_loaded {
|
||||||
e.loaded = true;
|
e.loaded = true;
|
||||||
}
|
}
|
||||||
|
// Union the per-node capabilities so a model loaded
|
||||||
|
// on several neurons reports every modality any of
|
||||||
|
// them advertises.
|
||||||
|
for cap in &entry.capabilities {
|
||||||
|
if !e.capabilities.contains(cap) {
|
||||||
|
e.capabilities.push(cap.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.or_insert_with(|| CortexModelEntry {
|
.or_insert_with(|| CortexModelEntry {
|
||||||
id: model_id.clone(),
|
id: model_id.clone(),
|
||||||
@@ -449,6 +460,7 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
// feasibility; leave empty.
|
// feasibility; leave empty.
|
||||||
feasible_on: Vec::new(),
|
feasible_on: Vec::new(),
|
||||||
locations: vec![location],
|
locations: vec![location],
|
||||||
|
capabilities: entry.capabilities.clone(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -498,6 +510,9 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
loaded: false,
|
loaded: false,
|
||||||
feasible_on: Vec::new(),
|
feasible_on: Vec::new(),
|
||||||
locations: vec![location],
|
locations: vec![location],
|
||||||
|
// A model that's only mid-prewarm has no loaded
|
||||||
|
// location to read capabilities from yet.
|
||||||
|
capabilities: Vec::new(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -527,6 +542,7 @@ async fn list_models(State(fleet): State<Arc<CortexState>>) -> Json<Value> {
|
|||||||
loaded: target_entry.loaded,
|
loaded: target_entry.loaded,
|
||||||
feasible_on: target_entry.feasible_on,
|
feasible_on: target_entry.feasible_on,
|
||||||
locations: target_entry.locations,
|
locations: target_entry.locations,
|
||||||
|
capabilities: target_entry.capabilities,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,12 +107,14 @@ async fn poll_neuron(fleet: &CortexState, name: &str, endpoint: &str) {
|
|||||||
.and_modify(|e| {
|
.and_modify(|e| {
|
||||||
e.status = status;
|
e.status = status;
|
||||||
e.vram_estimate_mb = upstream.vram_used_mb;
|
e.vram_estimate_mb = upstream.vram_used_mb;
|
||||||
|
e.capabilities = upstream.capabilities.clone();
|
||||||
})
|
})
|
||||||
.or_insert_with(|| ModelEntry {
|
.or_insert_with(|| ModelEntry {
|
||||||
id: upstream.id.clone(),
|
id: upstream.id.clone(),
|
||||||
status,
|
status,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: upstream.vram_used_mb,
|
vram_estimate_mb: upstream.vram_used_mb,
|
||||||
|
capabilities: upstream.capabilities.clone(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -244,6 +244,7 @@ async fn cold_load(
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: Some(chrono::Utc::now()),
|
last_accessed: Some(chrono::Utc::now()),
|
||||||
vram_estimate_mb: profile.vram_mb,
|
vram_estimate_mb: profile.vram_mb,
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -292,7 +293,7 @@ async fn profile_to_spec(
|
|||||||
};
|
};
|
||||||
|
|
||||||
ModelSpec {
|
ModelSpec {
|
||||||
model_id: profile.id.clone(),
|
model_id: qualified_model_id(profile),
|
||||||
harness: profile.harness.clone(),
|
harness: profile.harness.clone(),
|
||||||
quant: profile.quant.clone(),
|
quant: profile.quant.clone(),
|
||||||
tensor_parallel,
|
tensor_parallel,
|
||||||
@@ -300,6 +301,22 @@ async fn profile_to_spec(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prefix the catalogue id with the scheme when one is declared, so
|
||||||
|
/// neuron resolves the load against the right registry. Without this,
|
||||||
|
/// a profile pointing at the helexa registry would resolve via
|
||||||
|
/// neuron's `default_source` (typically `huggingface`) and fetch
|
||||||
|
/// bytes from the wrong place. Profiles that omit `source` continue
|
||||||
|
/// to pass the bare id through, preserving the pre-Phase-3 contract.
|
||||||
|
///
|
||||||
|
/// Stays at module scope (not nested in `profile_to_spec`) so the unit
|
||||||
|
/// tests can exercise it without spinning up CortexState topology.
|
||||||
|
fn qualified_model_id(profile: &ModelProfile) -> String {
|
||||||
|
match profile.source.as_deref() {
|
||||||
|
Some(scheme) if !scheme.is_empty() => format!("{scheme}:{}", profile.id),
|
||||||
|
_ => profile.id.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
|
/// Resolve neuron's `/models/{id}/endpoint` to its inference URL and
|
||||||
/// build the final `RouteDecision`. Shared by all three priority
|
/// build the final `RouteDecision`. Shared by all three priority
|
||||||
/// branches above.
|
/// branches above.
|
||||||
@@ -375,7 +392,43 @@ fn rewrite_loopback_host(inference_url: &str, neuron_endpoint: &str) -> Option<S
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::rewrite_loopback_host;
|
use super::{ModelProfile, qualified_model_id, rewrite_loopback_host};
|
||||||
|
|
||||||
|
fn bare_profile(id: &str, source: Option<&str>) -> ModelProfile {
|
||||||
|
ModelProfile {
|
||||||
|
id: id.into(),
|
||||||
|
harness: "candle".into(),
|
||||||
|
quant: None,
|
||||||
|
vram_mb: None,
|
||||||
|
min_devices: 1,
|
||||||
|
min_device_vram_mb: None,
|
||||||
|
pinned_on: vec![],
|
||||||
|
source: source.map(String::from),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn qualified_id_passes_through_when_source_absent() {
    // No `source` on the profile: the id must come back untouched.
    let profile = bare_profile("Qwen/Qwen3-30B", None);
    let qualified = qualified_model_id(&profile);
    assert_eq!(qualified, "Qwen/Qwen3-30B");
}
|
||||||
|
|
||||||
|
#[test]
fn qualified_id_prefixes_when_source_set() {
    // A declared source becomes the `scheme:` prefix of the id.
    let profile = bare_profile("Helexa/Qwen3.6-27B-Uncensored", Some("helexa"));
    let qualified = qualified_model_id(&profile);
    assert_eq!(qualified, "helexa:Helexa/Qwen3.6-27B-Uncensored");
}
|
||||||
|
|
||||||
|
#[test]
fn qualified_id_passes_through_when_source_is_empty_string() {
    // `Some("")` counts as "no source": the bare id goes out and
    // neuron's default_source substitution applies.
    let profile = bare_profile("Qwen/Qwen3-30B", Some(""));
    let qualified = qualified_model_id(&profile);
    assert_eq!(qualified, "Qwen/Qwen3-30B");
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn rewrites_localhost_keeps_port_and_path() {
|
fn rewrites_localhost_keeps_port_and_path() {
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ async fn test_alias_resolves_in_chat_completions() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: None,
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -154,6 +155,7 @@ async fn test_aliases_surface_in_v1_models() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: Some(2000),
|
vram_estimate_mb: Some(2000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -235,6 +237,7 @@ async fn test_alias_falls_through_for_unmapped_model() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: None,
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,6 +305,7 @@ pub async fn spawn_gateway_with_state(mock_url: &str) -> (Arc<CortexState>, Stri
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: Some(8000),
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ async fn test_evict_lru_model() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: Some(Utc::now() - chrono::Duration::hours(2)),
|
last_accessed: Some(Utc::now() - chrono::Duration::hours(2)),
|
||||||
vram_estimate_mb: Some(8000),
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
node.models.insert(
|
node.models.insert(
|
||||||
@@ -100,6 +101,7 @@ async fn test_evict_lru_model() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: Some(Utc::now()),
|
last_accessed: Some(Utc::now()),
|
||||||
vram_estimate_mb: Some(8000),
|
vram_estimate_mb: Some(8000),
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -163,6 +165,7 @@ async fn test_eviction_increments_lifecycle_cycles() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: None,
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -118,6 +118,87 @@ async fn test_poller_updates_gateway_models_endpoint() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_models_endpoint_unions_capabilities_across_nodes() {
|
||||||
|
// C3: two neurons each have the same model loaded but advertise
|
||||||
|
// different capability sets. The gateway's /v1/models must report
|
||||||
|
// the union — a model loaded text-only on one node and
|
||||||
|
// text+vision on another is vision-capable to the fleet.
|
||||||
|
let node_a = common::spawn_mock_neuron_with_models(json!([
|
||||||
|
{"id": "shared-model", "harness": "candle", "status": "loaded", "devices": [0], "vram_used_mb": null, "capabilities": ["text"]}
|
||||||
|
]))
|
||||||
|
.await;
|
||||||
|
let node_b = common::spawn_mock_neuron_with_models(json!([
|
||||||
|
{"id": "shared-model", "harness": "candle", "status": "loaded", "devices": [1], "vram_used_mb": null, "capabilities": ["text", "vision"]}
|
||||||
|
]))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let config = GatewayConfig {
|
||||||
|
gateway: GatewaySettings {
|
||||||
|
listen: "127.0.0.1:0".into(),
|
||||||
|
metrics_listen: "127.0.0.1:0".into(),
|
||||||
|
},
|
||||||
|
eviction: EvictionSettings {
|
||||||
|
strategy: EvictionStrategy::Lru,
|
||||||
|
defrag_after_cycles: 0,
|
||||||
|
},
|
||||||
|
neurons: vec![
|
||||||
|
NeuronEndpoint {
|
||||||
|
name: "node-a".into(),
|
||||||
|
endpoint: node_a,
|
||||||
|
},
|
||||||
|
NeuronEndpoint {
|
||||||
|
name: "node-b".into(),
|
||||||
|
endpoint: node_b,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
models_config: "/dev/null".into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let fleet = Arc::new(CortexState::from_config(&config));
|
||||||
|
cortex_gateway::poller::poll_once(&fleet).await;
|
||||||
|
|
||||||
|
let app = cortex_gateway::build_app(Arc::clone(&fleet));
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let body: serde_json::Value = client
|
||||||
|
.get(format!("http://{addr}/v1/models"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("request should succeed")
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let model = body["data"]
|
||||||
|
.as_array()
|
||||||
|
.expect("data array")
|
||||||
|
.iter()
|
||||||
|
.find(|m| m["id"] == "shared-model")
|
||||||
|
.expect("shared-model should be present");
|
||||||
|
|
||||||
|
let caps: Vec<&str> = model["capabilities"]
|
||||||
|
.as_array()
|
||||||
|
.expect("capabilities array")
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.as_str())
|
||||||
|
.collect();
|
||||||
|
assert!(caps.contains(&"text"), "union must include text: {caps:?}");
|
||||||
|
assert!(
|
||||||
|
caps.contains(&"vision"),
|
||||||
|
"union must include vision: {caps:?}"
|
||||||
|
);
|
||||||
|
assert_eq!(caps.len(), 2, "union must not duplicate text: {caps:?}");
|
||||||
|
|
||||||
|
// Both nodes hold the model, so two locations regardless of caps.
|
||||||
|
assert_eq!(model["locations"].as_array().unwrap().len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_poller_marks_unreachable_node_unhealthy() {
|
async fn test_poller_marks_unreachable_node_unhealthy() {
|
||||||
let config = GatewayConfig {
|
let config = GatewayConfig {
|
||||||
@@ -216,6 +297,7 @@ async fn test_poller_removes_stale_models() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: None,
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
node.models.insert(
|
node.models.insert(
|
||||||
@@ -225,6 +307,7 @@ async fn test_poller_removes_stale_models() {
|
|||||||
status: ModelStatus::Loaded,
|
status: ModelStatus::Loaded,
|
||||||
last_accessed: None,
|
last_accessed: None,
|
||||||
vram_estimate_mb: None,
|
vram_estimate_mb: None,
|
||||||
|
capabilities: Vec::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,20 +76,31 @@ cudarc = { version = "0.19", optional = true, default-features = false, features
|
|||||||
half = { version = "2.5", optional = true }
|
half = { version = "2.5", optional = true }
|
||||||
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||||
hf-hub = { version = "0.4", features = ["tokio"] }
|
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||||
# Jinja-compatible template renderer for the model's
|
# Jinja-compatible template renderer for the model's chat template
|
||||||
# `tokenizer_config.json::chat_template`. Hugging Face's chat
|
# (standalone `chat_template.jinja` or `tokenizer_config.json::chat_template`).
|
||||||
# templates use a strict subset of Jinja2 that minijinja supports
|
# Hugging Face's chat templates lean on Python string semantics; we
|
||||||
# out of the box. ~80KB compiled; pure Rust, no async surface.
|
# bridge them with `minijinja-contrib`'s `pycompat` callback (str
|
||||||
# Features: `builtins` for the `is defined` / `default` filters HF
|
# methods like `startswith`/`split`/`strip`) plus a `raise_exception`
|
||||||
# templates use; `json` for `tojson` (some Qwen3 templates emit
|
# global. Features: `builtins` for `is defined` / `default`; `json`
|
||||||
# tool definitions via tojson); `serde` so we can hand it a
|
# for `tojson`; `serde` so we can hand it a serde_json::Value context.
|
||||||
# serde_json::Value as the context.
|
|
||||||
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
||||||
|
# Python-compatibility shim: the Qwen3-VL / Qwen3.6 template uses
|
||||||
|
# `content.startswith(...)`, `.endswith(...)`, `.split(...)`,
|
||||||
|
# `.rstrip(...)`, `.lstrip(...)` — Python str methods minijinja doesn't
|
||||||
|
# implement natively. `pycompat::unknown_method_callback` supplies them.
|
||||||
|
minijinja-contrib = { version = "2", features = ["pycompat"] }
|
||||||
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||||
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||||
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||||
# without materialising the full tensor on device.
|
# without materialising the full tensor on device.
|
||||||
safetensors = "0.7"
|
safetensors = "0.7"
|
||||||
|
# Vision capability for Qwen3.6 (Stage A of the vision plan in
|
||||||
|
# doc/vision-qwen3_6-spec.md). `image` decodes PNG/JPEG/etc from
|
||||||
|
# the bytes embedded in `data:image/...;base64,...` content parts;
|
||||||
|
# `base64` does the URI decode. Default-features off on `image` to
|
||||||
|
# avoid pulling in audio/video formats we don't need.
|
||||||
|
image = { version = "0.25", default-features = false, features = ["png", "jpeg", "webp", "bmp", "gif"] }
|
||||||
|
base64 = "0.22"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { workspace = true, features = ["test-util"] }
|
tokio = { workspace = true, features = ["test-util"] }
|
||||||
|
|||||||
@@ -250,6 +250,18 @@ async fn chat_completions(
|
|||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
|
Err(InferenceError::VisionUnsupported { model_id }) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!(
|
||||||
|
"model '{model_id}' does not support image input"
|
||||||
|
),
|
||||||
|
"code": "vision_unsupported",
|
||||||
|
"model_id": model_id,
|
||||||
|
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
Err(InferenceError::Other(e)) => (
|
Err(InferenceError::Other(e)) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Json(json!({"error": format!("{e:#}")})),
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
@@ -289,6 +301,18 @@ async fn chat_completions(
|
|||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
|
Err(InferenceError::VisionUnsupported { model_id }) => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!(
|
||||||
|
"model '{model_id}' does not support image input"
|
||||||
|
),
|
||||||
|
"code": "vision_unsupported",
|
||||||
|
"model_id": model_id,
|
||||||
|
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
Err(InferenceError::Other(e)) => (
|
Err(InferenceError::Other(e)) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Json(json!({"error": format!("{e:#}")})),
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
@@ -452,6 +476,18 @@ fn inference_error_response(err: InferenceError) -> axum::response::Response {
|
|||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
.into_response(),
|
.into_response(),
|
||||||
|
InferenceError::VisionUnsupported { model_id } => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(json!({
|
||||||
|
"error": format!(
|
||||||
|
"model '{model_id}' does not support image input"
|
||||||
|
),
|
||||||
|
"code": "vision_unsupported",
|
||||||
|
"model_id": model_id,
|
||||||
|
"suggestion": "load a vision-capable model or remove image_url content parts",
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
InferenceError::Other(e) => (
|
InferenceError::Other(e) => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
Json(json!({"error": format!("{e:#}")})),
|
Json(json!({"error": format!("{e:#}")})),
|
||||||
|
|||||||
@@ -6,8 +6,18 @@ use figment::{
|
|||||||
providers::{Env, Format, Toml},
|
providers::{Env, Format, Toml},
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
/// Default scheme name applied to bare `org/name` model ids when no
|
||||||
|
/// `[harness.candle.default_source]` is set. Keeps existing operator
|
||||||
|
/// configs (which know nothing about schemes) working unchanged.
|
||||||
|
pub const DEFAULT_SOURCE_SCHEME: &str = "huggingface";
|
||||||
|
|
||||||
|
/// Endpoint URL for the default huggingface source, used when no
|
||||||
|
/// `[harness.candle.sources.huggingface]` is configured.
|
||||||
|
pub const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct NeuronConfig {
|
pub struct NeuronConfig {
|
||||||
#[serde(default = "default_port")]
|
#[serde(default = "default_port")]
|
||||||
@@ -37,8 +47,88 @@ pub struct HarnessSettings {
|
|||||||
pub struct CandleHarnessConfig {
|
pub struct CandleHarnessConfig {
|
||||||
/// HuggingFace cache directory for model weights.
|
/// HuggingFace cache directory for model weights.
|
||||||
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
/// When unset, defers to hf-hub's default (~/.cache/huggingface).
|
||||||
|
///
|
||||||
|
/// Retained for back-compat — operators with existing
|
||||||
|
/// `hf_cache = "..."` configs continue to work. Treated as the
|
||||||
|
/// `huggingface` source's cache_dir when a sources table isn't
|
||||||
|
/// provided.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub hf_cache: Option<PathBuf>,
|
pub hf_cache: Option<PathBuf>,
|
||||||
|
|
||||||
|
/// Default source scheme applied to bare `org/name` model ids
|
||||||
|
/// (those without an explicit `scheme:` prefix). When unset, falls
|
||||||
|
/// back to `DEFAULT_SOURCE_SCHEME` ("huggingface").
|
||||||
|
#[serde(default)]
|
||||||
|
pub default_source: Option<String>,
|
||||||
|
|
||||||
|
/// Per-scheme source endpoints. Each entry maps a scheme name
|
||||||
|
/// (`huggingface`, `helexa`, an operator's mirror tag, …) to its
|
||||||
|
/// endpoint URL, optional auth env var, and optional cache
|
||||||
|
/// directory.
|
||||||
|
///
|
||||||
|
/// When absent or missing the `huggingface` key, the loader
|
||||||
|
/// synthesises a `huggingface` entry pointing at
|
||||||
|
/// `https://huggingface.co` with `hf_cache` (above) as its
|
||||||
|
/// cache_dir. This keeps single-source configs ergonomic.
|
||||||
|
#[serde(default)]
|
||||||
|
pub sources: HashMap<String, SourceConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-scheme source configuration. Mirrors the shape `hf_hub::ApiBuilder`
/// needs: endpoint URL, optional auth token (read from an env var so
/// secrets stay out of the config file), and optional cache directory
/// disambiguated per source to prevent mirror-vs-canonical collisions.
///
/// `Default` is derived, so a default-constructed value has an empty
/// `endpoint` and neither `auth_env` nor `cache_dir` set.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SourceConfig {
    /// Base URL of the registry. Must speak the HF-compatible wire
    /// format (siblings listing at
    /// `/api/models/{org}/{name}[/revision/{rev}]`, blob fetch at
    /// `/{org}/{name}/resolve/{rev}/{path}`).
    ///
    /// The only field without `#[serde(default)]`, so it is required
    /// whenever a source entry is deserialized.
    pub endpoint: String,

    /// Environment variable name to read for the bearer token used
    /// against this source. `None` = anonymous. Reading from env
    /// (vs. literal token in the config) keeps secrets out of TOML.
    #[serde(default)]
    pub auth_env: Option<String>,

    /// Cache directory for this source. The hf-hub
    /// `models--{org}--{name}/snapshots/...` tree lives directly
    /// under this path, so distinct sources serving the same
    /// `org/name` cannot collide on disk.
    ///
    /// `None` means "share the harness `hf_cache` directory" — only
    /// safe when the operator has exactly one source configured.
    ///
    /// NOTE(review): `CandleHarnessConfig::effective_sources` only copies
    /// `hf_cache` into the synthesised `huggingface` entry; for explicit
    /// entries the `None` fallback is presumably applied by the loader —
    /// confirm against the consumer of this field.
    #[serde(default)]
    pub cache_dir: Option<PathBuf>,
}
|
||||||
|
|
||||||
|
impl CandleHarnessConfig {
|
||||||
|
/// Resolve the effective sources map for this config, synthesising
|
||||||
|
/// a `huggingface` entry from legacy fields (`hf_cache`) when the
|
||||||
|
/// operator hasn't supplied a sources table. Idempotent.
|
||||||
|
///
|
||||||
|
/// Returns a fresh map rather than mutating self so the original
|
||||||
|
/// (operator-typed) config can still be serialized back to TOML
|
||||||
|
/// for diagnostics.
|
||||||
|
pub fn effective_sources(&self) -> HashMap<String, SourceConfig> {
|
||||||
|
let mut out = self.sources.clone();
|
||||||
|
out.entry(DEFAULT_SOURCE_SCHEME.to_string())
|
||||||
|
.or_insert_with(|| SourceConfig {
|
||||||
|
endpoint: DEFAULT_HF_ENDPOINT.to_string(),
|
||||||
|
auth_env: Some("HF_TOKEN".to_string()),
|
||||||
|
cache_dir: self.hf_cache.clone(),
|
||||||
|
});
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Effective default scheme. Falls back to `DEFAULT_SOURCE_SCHEME`
|
||||||
|
/// when the operator hasn't pinned one.
|
||||||
|
pub fn effective_default_source(&self) -> &str {
|
||||||
|
self.default_source
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(DEFAULT_SOURCE_SCHEME)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_port() -> u16 {
|
fn default_port() -> u16 {
|
||||||
@@ -65,3 +155,109 @@ impl Default for NeuronConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme key the synthesised / explicit huggingface entry lives under.
    const HF: &str = "huggingface";

    /// Build a fully-populated `SourceConfig` fixture.
    fn source(endpoint: &str, auth_env: &str, cache_dir: &str) -> SourceConfig {
        SourceConfig {
            endpoint: endpoint.into(),
            auth_env: Some(auth_env.into()),
            cache_dir: Some(PathBuf::from(cache_dir)),
        }
    }

    /// Harness config with a legacy `hf_cache` plus a sources table
    /// holding exactly one explicit entry under `scheme`.
    fn config_with(hf_cache: &str, scheme: &str, entry: SourceConfig) -> CandleHarnessConfig {
        let mut sources = HashMap::new();
        sources.insert(scheme.to_string(), entry);
        CandleHarnessConfig {
            hf_cache: Some(PathBuf::from(hf_cache)),
            sources,
            ..Default::default()
        }
    }

    #[test]
    fn effective_sources_synthesises_huggingface_when_absent() {
        let sources = CandleHarnessConfig::default().effective_sources();
        assert!(sources.contains_key(HF));
        let hf = &sources[HF];
        assert_eq!(hf.endpoint, DEFAULT_HF_ENDPOINT);
        assert_eq!(hf.auth_env.as_deref(), Some("HF_TOKEN"));
        assert_eq!(hf.cache_dir, None);
    }

    #[test]
    fn effective_sources_carries_legacy_hf_cache_into_synth_entry() {
        // Operators with only `hf_cache = "/archive3/..."` in their
        // config must keep loading from that storage: the synthesised
        // entry has to inherit it.
        let legacy = PathBuf::from("/archive3/llm-cache");
        let cfg = CandleHarnessConfig {
            hf_cache: Some(legacy.clone()),
            ..Default::default()
        };
        let sources = cfg.effective_sources();
        assert_eq!(sources[HF].cache_dir, Some(legacy));
    }

    #[test]
    fn effective_sources_preserves_explicit_huggingface_entry() {
        // An explicitly typed `[harness.candle.sources.huggingface]`
        // wins over every synth default, including the legacy cache.
        let explicit = source(
            "https://huggingface.example.org",
            "MY_TOKEN",
            "/operator-cache",
        );
        let cfg = config_with("/legacy-cache", HF, explicit);
        let effective = cfg.effective_sources();
        let hf = &effective[HF];
        assert_eq!(hf.endpoint, "https://huggingface.example.org");
        assert_eq!(hf.auth_env.as_deref(), Some("MY_TOKEN"));
        assert_eq!(hf.cache_dir.as_deref(), Some(Path::new("/operator-cache")));
    }

    #[test]
    fn effective_sources_includes_helexa_alongside_synth_huggingface() {
        let helexa = source(
            "https://registry.helexa.ai",
            "HELEXA_TOKEN",
            "/archive3/llm-cache/helexa",
        );
        let cfg = config_with("/archive3/llm-cache/huggingface", "helexa", helexa);
        let effective = cfg.effective_sources();
        assert_eq!(effective.len(), 2);
        assert_eq!(effective["helexa"].endpoint, "https://registry.helexa.ai");
        // The huggingface entry is still synthesised from `hf_cache`.
        assert_eq!(
            effective[HF].cache_dir.as_deref(),
            Some(Path::new("/archive3/llm-cache/huggingface"))
        );
    }

    #[test]
    fn effective_default_source_falls_back() {
        let scheme_owner = CandleHarnessConfig::default();
        assert_eq!(
            scheme_owner.effective_default_source(),
            DEFAULT_SOURCE_SCHEME
        );
    }

    #[test]
    fn effective_default_source_honours_explicit() {
        let scheme_owner = CandleHarnessConfig {
            default_source: Some("helexa".into()),
            ..Default::default()
        };
        assert_eq!(scheme_owner.effective_default_source(), "helexa");
    }
}
|
||||||
|
|||||||
@@ -93,12 +93,13 @@ impl Qwen3_5DecoderLayer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
attn_mask: Option<&Tensor>,
|
attn_mask: Option<&Tensor>,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let h = self.input_layernorm.forward(x)?;
|
let h = self.input_layernorm.forward(x)?;
|
||||||
let attn_out = match &mut self.attention {
|
let attn_out = match &mut self.attention {
|
||||||
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
AttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
|
||||||
// Linear attention ignores attn_mask + offset; its causal
|
// Linear attention ignores attn_mask + rope; its causal
|
||||||
// structure is baked into the recurrent state lifecycle.
|
// structure is baked into the recurrent state lifecycle.
|
||||||
AttentionKind::Linear(net) => net.forward(&h)?,
|
AttentionKind::Linear(net) => net.forward(&h)?,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -96,7 +96,8 @@ impl Qwen3_5Attention {
|
|||||||
&mut self,
|
&mut self,
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
attn_mask: Option<&Tensor>,
|
attn_mask: Option<&Tensor>,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let (b, l, _) = x.dims3()?;
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
@@ -131,8 +132,9 @@ impl Qwen3_5Attention {
|
|||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.contiguous()?;
|
.contiguous()?;
|
||||||
|
|
||||||
// 3. RoPE on q, k.
|
// 3. RoPE on q, k (cos/sin built once per forward by the model —
|
||||||
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
// interleaved M-RoPE for image tokens, plain for text).
|
||||||
|
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
|
||||||
|
|
||||||
// 4. KV cache.
|
// 4. KV cache.
|
||||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
|
|||||||
@@ -737,6 +737,8 @@ mod tests {
|
|||||||
rope_theta: 10000.0,
|
rope_theta: 10000.0,
|
||||||
partial_rotary_factor: 1.0,
|
partial_rotary_factor: 1.0,
|
||||||
rope_type: None,
|
rope_type: None,
|
||||||
|
mrope_section: Vec::new(),
|
||||||
|
mrope_interleaved: false,
|
||||||
},
|
},
|
||||||
rms_norm_eps: 1e-6,
|
rms_norm_eps: 1e-6,
|
||||||
tie_word_embeddings: false,
|
tie_word_embeddings: false,
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ pub mod linear_attn;
|
|||||||
pub mod mlp;
|
pub mod mlp;
|
||||||
pub mod rmsnorm;
|
pub mod rmsnorm;
|
||||||
pub mod rope;
|
pub mod rope;
|
||||||
|
pub mod vision;
|
||||||
|
|
||||||
use decoder::Qwen3_5DecoderLayer;
|
use decoder::Qwen3_5DecoderLayer;
|
||||||
use rmsnorm::Qwen3_5RmsNorm;
|
use rmsnorm::Qwen3_5RmsNorm;
|
||||||
@@ -99,6 +100,20 @@ pub struct Config {
|
|||||||
pub model_type: String,
|
pub model_type: String,
|
||||||
/// The text-side hyperparameters. Everything we actually need.
|
/// The text-side hyperparameters. Everything we actually need.
|
||||||
pub text_config: TextConfig,
|
pub text_config: TextConfig,
|
||||||
|
/// Vision tower hyperparameters. Present on multimodal
|
||||||
|
/// checkpoints (e.g. Qwen/Qwen3.6-27B); absent on text-only
|
||||||
|
/// variants. When present, `Qwen3_5ForCausalLM::new` loads the
|
||||||
|
/// vision tower alongside the language model so vision-bearing
|
||||||
|
/// requests can splice image embeddings at `<|image_pad|>` token
|
||||||
|
/// positions.
|
||||||
|
#[serde(default)]
|
||||||
|
pub vision_config: Option<vision::VisionConfig>,
|
||||||
|
/// Token id the chat template emits per image patch group.
|
||||||
|
/// Mirrors the LM tokenizer's `<|image_pad|>` id (248056 for
|
||||||
|
/// Qwen3.6). The runtime locates these in the prompt and splices
|
||||||
|
/// in `VisionTower::forward` output. `None` for text-only models.
|
||||||
|
#[serde(default)]
|
||||||
|
pub image_token_id: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
/// Inner config (the `text_config` block). Mirrors the Qwen3 layout
|
||||||
@@ -176,11 +191,12 @@ fn default_hidden_act() -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
/// Nested `rope_parameters` block from a Qwen3-Next `config.json`.
|
||||||
/// `mrope_section` and `mrope_interleaved` are accepted via the
|
///
|
||||||
/// `#[serde(default)]` flatten-tolerance below but ignored — we treat
|
/// For text-only inference the three MRoPE position grids carry
|
||||||
/// MRoPE as plain RoPE for text-only inference (the three position
|
/// identical ids, so the interleave is a no-op and plain RoPE applies.
|
||||||
/// grids carry identical ids when there's no vision input, so the
|
/// For vision inputs `mrope_section` + `mrope_interleaved` drive the
|
||||||
/// interleaving is a no-op).
|
/// per-axis (text/height/width) rotary used by image tokens — see
|
||||||
|
/// `rope.rs`.
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct RopeParameters {
|
pub struct RopeParameters {
|
||||||
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
/// Base for the inverse-frequency computation. Qwen3.6: 10_000_000.
|
||||||
@@ -196,6 +212,16 @@ pub struct RopeParameters {
|
|||||||
/// implemented here.
|
/// implemented here.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub rope_type: Option<String>,
|
pub rope_type: Option<String>,
|
||||||
|
/// MRoPE per-axis section sizes `[text, height, width]` — e.g.
|
||||||
|
/// `[11, 11, 10]` for Qwen3.6, summing to the rotary half-dim.
|
||||||
|
/// Empty for models that don't declare MRoPE (→ plain RoPE).
|
||||||
|
#[serde(default)]
|
||||||
|
pub mrope_section: Vec<usize>,
|
||||||
|
/// Whether the three MRoPE axes are interleaved per-frequency
|
||||||
|
/// (Qwen3-VL / Qwen3.6 style, `true`) rather than block-concatenated
|
||||||
|
/// (Qwen2-VL style, `false`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub mrope_interleaved: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_rope_theta() -> f64 {
|
fn default_rope_theta() -> f64 {
|
||||||
@@ -206,6 +232,80 @@ fn default_partial_rotary_factor() -> f32 {
|
|||||||
1.0
|
1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Splice rows from `img` into `h` at `positions`. Stage B helper.
|
||||||
|
///
|
||||||
|
/// `h`: `(1, L, hidden)` — the LM's input embedding tensor after
|
||||||
|
/// `embed_tokens.forward`.
|
||||||
|
/// `img`: `(N_img, hidden)` — image embeddings, one row per
|
||||||
|
/// `<|image_pad|>` token in the prompt. Must already be in `h.dtype()`.
|
||||||
|
/// `positions`: indices into the `L` axis where image rows go;
|
||||||
|
/// `positions.len() == N_img`.
|
||||||
|
///
|
||||||
|
/// Approach: group `positions` into contiguous runs (because the chat
|
||||||
|
/// template emits `<|vision_start|><|image_pad|>×N<|vision_end|>` —
|
||||||
|
/// the pad tokens for each image land in one contiguous span), then
|
||||||
|
/// `slice_assign` per run. For typical Qwen3.6 requests this is one
|
||||||
|
/// or two runs per image; `slice_assign` does one tensor copy per
|
||||||
|
/// run, which is cheap relative to the decoder forward pass.
|
||||||
|
pub(crate) fn splice_runs(
|
||||||
|
h: &Tensor,
|
||||||
|
img: &Tensor,
|
||||||
|
positions: &[u32],
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
debug_assert!(
|
||||||
|
!positions.is_empty(),
|
||||||
|
"splice_runs precondition: non-empty positions"
|
||||||
|
);
|
||||||
|
let hidden = h.dim(2)?;
|
||||||
|
let mut out = h.clone();
|
||||||
|
let mut img_offset = 0_usize;
|
||||||
|
let mut run_start = positions[0] as usize;
|
||||||
|
let mut run_end_exclusive = run_start + 1;
|
||||||
|
for &p in &positions[1..] {
|
||||||
|
let p = p as usize;
|
||||||
|
if p == run_end_exclusive {
|
||||||
|
run_end_exclusive = p + 1;
|
||||||
|
} else {
|
||||||
|
apply_run(
|
||||||
|
&mut out,
|
||||||
|
img,
|
||||||
|
&mut img_offset,
|
||||||
|
run_start,
|
||||||
|
run_end_exclusive,
|
||||||
|
hidden,
|
||||||
|
)?;
|
||||||
|
run_start = p;
|
||||||
|
run_end_exclusive = p + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
apply_run(
|
||||||
|
&mut out,
|
||||||
|
img,
|
||||||
|
&mut img_offset,
|
||||||
|
run_start,
|
||||||
|
run_end_exclusive,
|
||||||
|
hidden,
|
||||||
|
)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_run(
|
||||||
|
out: &mut Tensor,
|
||||||
|
img: &Tensor,
|
||||||
|
img_offset: &mut usize,
|
||||||
|
run_start: usize,
|
||||||
|
run_end_exclusive: usize,
|
||||||
|
hidden: usize,
|
||||||
|
) -> candle_core::Result<()> {
|
||||||
|
let run_len = run_end_exclusive - run_start;
|
||||||
|
let slice = img
|
||||||
|
.narrow(0, *img_offset, run_len)?
|
||||||
|
.reshape((1, run_len, hidden))?;
|
||||||
|
*out = out.slice_assign(&[0..1, run_start..run_end_exclusive, 0..hidden], &slice)?;
|
||||||
|
*img_offset += run_len;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
/// Qwen3-Next base transformer (embedding + decoder stack + final
|
||||||
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
/// norm). Public so a TP variant in `harness/tp/tp_qwen3_5.rs` can
|
||||||
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
/// also build on it later — for now only `Qwen3_5ForCausalLM` is the
|
||||||
@@ -214,6 +314,16 @@ pub struct Qwen3_5Model {
|
|||||||
embed_tokens: Embedding,
|
embed_tokens: Embedding,
|
||||||
layers: Vec<Qwen3_5DecoderLayer>,
|
layers: Vec<Qwen3_5DecoderLayer>,
|
||||||
norm: Qwen3_5RmsNorm,
|
norm: Qwen3_5RmsNorm,
|
||||||
|
/// Shared with every full-attention layer; the model uses it to
|
||||||
|
/// build the per-forward cos/sin (interleaved M-RoPE for image
|
||||||
|
/// tokens, plain for text) once, which the layers then apply.
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
/// `offset + rope_delta` is the text-axis position during decode.
|
||||||
|
/// 0 for text-only; set from `get_rope_index` during a vision
|
||||||
|
/// prefill (image tokens compress the position space, so text after
|
||||||
|
/// the image resumes from a smaller counter than the sequence
|
||||||
|
/// index). Reset in `clear_kv_cache`.
|
||||||
|
rope_delta: i64,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@@ -265,6 +375,8 @@ impl Qwen3_5Model {
|
|||||||
embed_tokens,
|
embed_tokens,
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
|
rotary,
|
||||||
|
rope_delta: 0,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
@@ -278,6 +390,9 @@ impl Qwen3_5Model {
|
|||||||
for l in &mut self.layers {
|
for l in &mut self.layers {
|
||||||
l.clear_kv_cache();
|
l.clear_kv_cache();
|
||||||
}
|
}
|
||||||
|
// New request → no image-compressed position offset until the
|
||||||
|
// next vision prefill sets one.
|
||||||
|
self.rope_delta = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
@@ -289,8 +404,98 @@ impl Qwen3_5Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
self.forward_inner(input, offset, None, None, &[])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward with image-embedding splice. Stage B of the vision plan.
|
||||||
|
///
|
||||||
|
/// `input_ids`: `(1, L)` token ids — same shape the text-only
|
||||||
|
/// `forward` accepts (single-batch; multi-batch vision is not in
|
||||||
|
/// scope today).
|
||||||
|
/// `image_embeds`: `(N_image_tokens, hidden_size)` — concatenation
|
||||||
|
/// of every image's post-merger embedding (`VisionTower::forward`
|
||||||
|
/// output), in the same order images appear in the input. The
|
||||||
|
/// caller has already done the per-image patch-count expansion of
|
||||||
|
/// `<|image_pad|>` tokens in `input_ids`, so `N_image_tokens`
|
||||||
|
/// equals the number of `image_token_id` positions in `input_ids`.
|
||||||
|
/// `image_token_id`: the sentinel token (e.g. 248056 for Qwen3.6).
|
||||||
|
///
|
||||||
|
/// The splice replaces the LM's text-side embedding at each
|
||||||
|
/// `image_token_id` position with the corresponding row from
|
||||||
|
/// `image_embeds`. After the splice the decoder runs the interleaved
|
||||||
|
/// M-RoPE path: `grids` carries each image's post-merge LM grid
|
||||||
|
/// `(lm_gh, lm_gw)` so `get_rope_index` assigns image tokens their 2D
|
||||||
|
/// coordinates (dynamic resolution, #14).
|
||||||
|
pub fn forward_with_vision(
|
||||||
|
&mut self,
|
||||||
|
input_ids: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_embeds: &Tensor,
|
||||||
|
image_token_id: u32,
|
||||||
|
grids: &[(usize, usize)],
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
self.forward_inner(
|
||||||
|
input_ids,
|
||||||
|
offset,
|
||||||
|
Some(image_embeds),
|
||||||
|
Some(image_token_id),
|
||||||
|
grids,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward_inner(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_embeds: Option<&Tensor>,
|
||||||
|
image_token_id: Option<u32>,
|
||||||
|
grids: &[(usize, usize)],
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
let (b, l) = input.dims2()?;
|
let (b, l) = input.dims2()?;
|
||||||
let mut h = self.embed_tokens.forward(input)?;
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
// Vision path: splice image embeddings at `image_token_id`
|
||||||
|
// positions and build interleaved M-RoPE cos/sin so image tokens
|
||||||
|
// carry their 2D (lm_gh × lm_gw) grid coordinates. Text / decode skip the
|
||||||
|
// device→host id copy entirely and take the plain-RoPE fast path
|
||||||
|
// — bit-for-bit the pre-M-RoPE behaviour when `rope_delta == 0`.
|
||||||
|
let (cos, sin) = if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||||
|
// Token ids on CPU — reused for the splice + position ids.
|
||||||
|
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||||
|
|
||||||
|
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||||
|
for (idx, id) in ids.iter().enumerate() {
|
||||||
|
if *id == tok_id {
|
||||||
|
positions.push(idx as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let n_img_tokens = img.dim(0)?;
|
||||||
|
if positions.len() != n_img_tokens {
|
||||||
|
candle_core::bail!(
|
||||||
|
"forward_with_vision: prompt has {} image-token positions but \
|
||||||
|
image_embeds carries {} tokens — call build_prompt_for_request to \
|
||||||
|
ensure the per-image patch-count expansion has been applied",
|
||||||
|
positions.len(),
|
||||||
|
n_img_tokens,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !positions.is_empty() {
|
||||||
|
// Cast image_embeds to the LM's dtype, then splice the
|
||||||
|
// contiguous `<|image_pad|>` runs in place.
|
||||||
|
let img = img.to_dtype(self.dtype)?;
|
||||||
|
h = splice_runs(&h, &img, &positions)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (text, height, width, delta) = rope::get_rope_index(&ids, tok_id, grids)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||||
|
self.rope_delta = delta;
|
||||||
|
let pos = rope::mrope_position_tensor(&text, &height, &width, &self.device)?;
|
||||||
|
self.rotary.mrope_cos_sin(&pos)?
|
||||||
|
} else {
|
||||||
|
let base = (offset as i64 + self.rope_delta).max(0) as usize;
|
||||||
|
self.rotary.plain_cos_sin(base, l)?
|
||||||
|
};
|
||||||
|
|
||||||
// Causal mask only needed for L > 1 prefill; full-attention
|
// Causal mask only needed for L > 1 prefill; full-attention
|
||||||
// layers consume it via broadcast_add. Linear-attention layers
|
// layers consume it via broadcast_add. Linear-attention layers
|
||||||
// ignore the mask.
|
// ignore the mask.
|
||||||
@@ -300,7 +505,7 @@ impl Qwen3_5Model {
|
|||||||
Some(self.causal_mask(b, l, offset)?)
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
};
|
};
|
||||||
for layer in &mut self.layers {
|
for layer in &mut self.layers {
|
||||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
|
||||||
}
|
}
|
||||||
self.norm.forward(&h)
|
self.norm.forward(&h)
|
||||||
}
|
}
|
||||||
@@ -309,6 +514,15 @@ impl Qwen3_5Model {
|
|||||||
pub struct Qwen3_5ForCausalLM {
|
pub struct Qwen3_5ForCausalLM {
|
||||||
base: Qwen3_5Model,
|
base: Qwen3_5Model,
|
||||||
lm_head: Linear,
|
lm_head: Linear,
|
||||||
|
/// Vision tower (Stage A4). `None` for text-only checkpoints or
|
||||||
|
/// when the operator has opted out. When present, the harness's
|
||||||
|
/// `Job::EncodeImage` dispatch path runs `vision.forward(image)`
|
||||||
|
/// and the LM forward (Stage B) splices the result at
|
||||||
|
/// `image_token_id` positions in the input embedding stream.
|
||||||
|
vision: Option<vision::VisionTower>,
|
||||||
|
/// Mirrors `Config::image_token_id`. Cached here so the runtime
|
||||||
|
/// doesn't have to round-trip through the parsed config struct.
|
||||||
|
image_token_id: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Qwen3_5ForCausalLM {
|
impl Qwen3_5ForCausalLM {
|
||||||
@@ -324,7 +538,52 @@ impl Qwen3_5ForCausalLM {
|
|||||||
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
.with_context(|| format!("load '{}/lm_head/weight'", vb.prefix()))?;
|
||||||
Linear::new(weight, None)
|
Linear::new(weight, None)
|
||||||
};
|
};
|
||||||
Ok(Self { base, lm_head })
|
// Stage A4: load the vision tower when the config carries a
|
||||||
|
// `vision_config` block and the safetensors actually carry
|
||||||
|
// `model.visual.*` weights. The `Option<VisionConfig>` on the
|
||||||
|
// config makes this a single-source-of-truth decision —
|
||||||
|
// text-only checkpoints just leave `vision_config` unset and
|
||||||
|
// get `None` here without any extra plumbing.
|
||||||
|
let vision = if let Some(vcfg) = config.vision_config.clone() {
|
||||||
|
tracing::info!(
|
||||||
|
depth = vcfg.depth,
|
||||||
|
hidden_size = vcfg.hidden_size,
|
||||||
|
"loading qwen3_5 vision tower"
|
||||||
|
);
|
||||||
|
Some(
|
||||||
|
vision::VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||||
|
.context("load qwen3_5 vision tower (model.visual.*)")?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
base,
|
||||||
|
lm_head,
|
||||||
|
vision,
|
||||||
|
image_token_id: config.image_token_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// True when this checkpoint loaded a vision tower, i.e. the `vision`
/// field is `Some`. Used by the HTTP layer to advertise vision
/// capability in `/v1/models` and to reject image-bearing requests
/// against text-only loads with a clean 400.
pub fn has_vision(&self) -> bool {
    self.vision.is_some()
}
|
||||||
|
|
||||||
|
/// Vision tower handle, if loaded; `None` when no tower was loaded.
/// The device-worker `EncodeImage` job dispatches to
/// `vision.forward(image)`.
pub fn vision(&self) -> Option<&vision::VisionTower> {
    self.vision.as_ref()
}
|
||||||
|
|
||||||
|
/// `<|image_pad|>` token id from `config.json`, when known; returns
/// the value cached on this struct at construction time.
/// The Stage B prompt-builder uses this to count expansion targets
/// and the LM forward uses it to locate splice positions.
pub fn image_token_id(&self) -> Option<u32> {
    self.image_token_id
}
|
||||||
|
|
||||||
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
/// `input`: token-id tensor of shape `(B, L)`. Returns logits at
|
||||||
@@ -337,6 +596,25 @@ impl Qwen3_5ForCausalLM {
|
|||||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Stage B: forward with image-embedding splice. Mirrors `forward`
|
||||||
|
/// but routes through `Qwen3_5Model::forward_with_vision` so the
|
||||||
|
/// LM's input embeddings get the image patches spliced in at
|
||||||
|
/// `image_token_id` positions before the decoder stack runs.
|
||||||
|
pub fn forward_with_vision(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_embeds: &Tensor,
|
||||||
|
image_token_id: u32,
|
||||||
|
grids: &[(usize, usize)],
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden =
|
||||||
|
self.base
|
||||||
|
.forward_with_vision(input, offset, image_embeds, image_token_id, grids)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
self.base.clear_kv_cache();
|
self.base.clear_kv_cache();
|
||||||
}
|
}
|
||||||
@@ -394,4 +672,50 @@ mod tests {
|
|||||||
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
|
assert_eq!(cfg.text_config.rope_parameters.rope_theta, 10_000_000.0);
|
||||||
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
|
assert!((cfg.text_config.rope_parameters.partial_rotary_factor - 0.25).abs() < 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `splice_runs` replaces (1, L, H) embedding rows at the given
/// positions with rows from a (N_img, H) image-embedding tensor,
/// in the order positions are supplied.
#[test]
fn splice_runs_replaces_at_contiguous_positions() {
    // Only `Device` is needed; the previous `DType` import existed
    // solely to be silenced by a dummy `let _ = DType::F32;`.
    use candle_core::Device;

    let dev = Device::Cpu;
    // (1, L=5, H=2) text embeddings — encoded as floats so the
    // assertion can spot the change without dtype conversion.
    let h_vals: Vec<f32> = vec![
        10., 11., // pos 0
        20., 21., // pos 1
        30., 31., // pos 2
        40., 41., // pos 3
        50., 51., // pos 4
    ];
    let h = Tensor::from_vec(h_vals, (1, 5, 2), &dev).unwrap();

    // Two image embeddings to splice at positions 1 and 2 (a
    // contiguous run — single image emitting two patch tokens).
    let img_vals: Vec<f32> = vec![-1., -2., -3., -4.];
    let img = Tensor::from_vec(img_vals, (2, 2), &dev).unwrap();

    let out = splice_runs(&h, &img, &[1, 2]).unwrap();
    let flat: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
    // Rows 1 and 2 now hold the image rows; rows 0, 3 and 4 are untouched.
    assert_eq!(flat, vec![10., 11., -1., -2., -3., -4., 40., 41., 50., 51.]);
}
|
||||||
|
|
||||||
|
/// Two single-patch images at sequence positions 1 and 3: the
/// positions are not adjacent, so `splice_runs` has to treat them as
/// two separate runs and feed each one the next image row in order.
#[test]
fn splice_runs_handles_non_contiguous_runs() {
    use candle_core::Device;

    let cpu = Device::Cpu;
    // (1, L=5, H=2): row `i` holds the value `i + 1` twice.
    let text_rows: Vec<f32> = vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.];
    let text = Tensor::from_vec(text_rows, (1, 5, 2), &cpu).unwrap();
    // (N_img=2, H=2) image rows, negative so they stand out.
    let image_rows: Vec<f32> = vec![-1., -2., -3., -4.];
    let image = Tensor::from_vec(image_rows, (2, 2), &cpu).unwrap();

    let spliced = splice_runs(&text, &image, &[1, 3]).unwrap();

    let flat: Vec<f32> = spliced.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(flat, vec![1., 1., -1., -2., 3., 3., -3., -4., 5., 5.]);
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,27 @@
|
|||||||
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
//! Rotary position embedding for Qwen3-Next's full-attention layers.
|
||||||
//!
|
//!
|
||||||
//! Qwen3.6 ships with MRoPE (multimodal RoPE) machinery in the
|
//! Qwen3.6 declares **interleaved M-RoPE** (multimodal RoPE): the
|
||||||
//! reference Python — three position grids interleaved per
|
//! rotary half-dimension is split across three position axes —
|
||||||
//! `mrope_section`. For text-only inference all three grids carry the
|
//! `[text, height, width]` per `mrope_section` (`[11,11,10]` for
|
||||||
//! same position ids and the interleave is a no-op, so this module
|
//! Qwen3.6) — interleaved per-frequency. For **text** every token's
|
||||||
//! implements the plain (non-mrope) flavour: the standard inv_freq
|
//! three axes carry the same position id, so the interleave is a no-op
|
||||||
//! cosine/sine tables driven by `rope_theta` and `head_dim`.
|
//! and this reduces exactly to plain RoPE. For **image** tokens the
|
||||||
|
//! height/width axes carry the patch's 2D grid coordinates, which is
|
||||||
|
//! how the model reads the 14×14 patch layout (without it, all patches
|
||||||
|
//! share a height position and the image reads as vertical repetition).
|
||||||
//!
|
//!
|
||||||
//! Rotation flavour: **GLM-style** rotate-half (the second half of the
|
//! Two cos/sin builders feed a shared [`RotaryEmbedding::apply`]:
|
||||||
//! head dim is negated and swapped into the first). The reference
|
//! - [`RotaryEmbedding::plain_cos_sin`] narrows the precomputed tables
|
||||||
//! Python uses `apply_rotary_pos_emb` with `rotate_half`; candle's
|
//! at a scalar position — the text / decode fast path.
|
||||||
//! `rope_slow` is the matching helper.
|
//! - [`RotaryEmbedding::mrope_cos_sin`] builds per-token cos/sin from a
|
||||||
|
//! `(3, seq)` position-id tensor, blending the three axes' frequencies
|
||||||
|
//! at the interleave index sets — the vision-prefill path.
|
||||||
|
//!
|
||||||
|
//! Rotation flavour: **GLM-style** rotate-half (candle's `rope_slow`),
|
||||||
|
//! matching the reference Python's `apply_rotary_pos_emb` + `rotate_half`.
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{DType, Device, Tensor};
|
use candle_core::{DType, Device, IndexOp, Tensor};
|
||||||
|
|
||||||
use super::TextConfig;
|
use super::TextConfig;
|
||||||
|
|
||||||
@@ -21,6 +29,18 @@ use super::TextConfig;
|
|||||||
pub struct RotaryEmbedding {
|
pub struct RotaryEmbedding {
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
|
/// Inverse frequencies, shape `(1, rotary_dim/2)`. Retained (beyond
|
||||||
|
/// the precomputed `sin`/`cos` tables) so [`Self::mrope_cos_sin`] can
|
||||||
|
/// build cos/sin from arbitrary per-axis position ids.
|
||||||
|
inv_freq: Tensor,
|
||||||
|
/// Per-axis column masks over the rotary half-dim, shape `(1, half)`,
|
||||||
|
/// f32 0/1. `mask_t + mask_h + mask_w` partitions the columns; a
|
||||||
|
/// column belongs to exactly one axis. For a non-MRoPE config
|
||||||
|
/// `mask_t` is all-ones and the others all-zero (→ plain RoPE).
|
||||||
|
mask_t: Tensor,
|
||||||
|
mask_h: Tensor,
|
||||||
|
mask_w: Tensor,
|
||||||
|
dtype: DType,
|
||||||
/// Number of dims at the head's leading edge that the rotation
|
/// Number of dims at the head's leading edge that the rotation
|
||||||
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
/// covers. The remaining `head_dim - rotary_dim` dims pass through
|
||||||
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
/// unchanged. Qwen3-Next uses `partial_rotary_factor = 0.25`, so
|
||||||
@@ -29,6 +49,52 @@ pub struct RotaryEmbedding {
|
|||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the per-axis 0/1 column masks over the rotary half-dim from
/// `mrope_section`.
///
/// * `half` — number of rotary frequency slots (`rotary_dim / 2`).
/// * `section` — `[text, height, width]` slot counts. Anything other than
///   exactly three entries is treated as "no M-RoPE".
/// * `interleaved` — `true` for the Qwen3-VL layout (height at columns
///   1,4,7,…, width at 2,5,8,…), `false` for the Qwen2-VL layout
///   (contiguous `[text | height | width]` blocks).
///
/// Returns `(temporal, height, width)`, each of length `half`. Temporal is
/// the complement of height ∪ width, so the three masks always partition
/// `0..half`; with no usable section every column is temporal (plain RoPE).
/// If the section counts cover fewer than `half` columns, the uncovered
/// columns stay on the temporal axis in *both* layouts.
fn mrope_masks(
    half: usize,
    section: &[usize],
    interleaved: bool,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let mut mh = vec![0f32; half];
    let mut mw = vec![0f32; half];

    if let &[n_text, n_height, n_width] = section {
        if interleaved {
            // Every third column starting at 1 (height) / 2 (width), capped
            // both by the section count and by `half`.
            for i in (1..half).step_by(3).take(n_height) {
                mh[i] = 1.0;
            }
            for i in (2..half).step_by(3).take(n_width) {
                mw[i] = 1.0;
            }
        } else {
            // Contiguous blocks. Saturating adds keep an absurd config from
            // overflowing; `min(half)` clamps every boundary into range, and
            // the boundaries are monotone so the slices below are valid.
            let h_start = n_text.min(half);
            let h_end = n_text.saturating_add(n_height).min(half);
            // Width covers only its own `n_width` columns; a shortfall past
            // `w_end` falls through to the temporal mask below.
            let w_end = h_end.saturating_add(n_width).min(half);
            mh[h_start..h_end].fill(1.0);
            mw[h_end..w_end].fill(1.0);
        }
    }

    // Temporal = every column not claimed by height or width.
    let mt: Vec<f32> = mh
        .iter()
        .zip(&mw)
        .map(|(&h, &w)| if h == 0.0 && w == 0.0 { 1.0 } else { 0.0 })
        .collect();

    (mt, mh, mw)
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
pub fn new(dtype: DType, cfg: &TextConfig, dev: &Device) -> Result<Self> {
|
||||||
let head_dim = cfg.head_dim;
|
let head_dim = cfg.head_dim;
|
||||||
@@ -52,44 +118,88 @@ impl RotaryEmbedding {
|
|||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
.map(|i| 1f32 / rope.rope_theta.powf(i as f64 / rotary_dim as f64) as f32)
|
||||||
.collect();
|
.collect();
|
||||||
let n = inv_freq.len();
|
let half = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, n), dev)?.to_dtype(DType::F32)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, half), dev)?.to_dtype(DType::F32)?;
|
||||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
|
||||||
|
// MRoPE axis masks. `sum(mrope_section)` should equal `half`;
|
||||||
|
// warn-tolerant: any shortfall just stays on the temporal axis.
|
||||||
|
let (mt, mh, mw) = mrope_masks(half, &rope.mrope_section, rope.mrope_interleaved);
|
||||||
|
let mask_t = Tensor::from_vec(mt, (1, half), dev)?;
|
||||||
|
let mask_h = Tensor::from_vec(mh, (1, half), dev)?;
|
||||||
|
let mask_w = Tensor::from_vec(mw, (1, half), dev)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
inv_freq,
|
||||||
|
mask_t,
|
||||||
|
mask_h,
|
||||||
|
mask_w,
|
||||||
|
dtype,
|
||||||
rotary_dim,
|
rotary_dim,
|
||||||
head_dim,
|
head_dim,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply RoPE to q, k.
|
/// cos/sin for a contiguous run of `seq_len` positions starting at
|
||||||
///
|
/// `pos`, by narrowing the precomputed tables. The text / decode
|
||||||
/// `q`, `k` shape: `(B, H, L, head_dim)`. `offset` is the index
|
/// path (all three MRoPE axes equal → plain RoPE). Shape
|
||||||
/// into the cached cos/sin table — the position of the first token
|
/// `(seq_len, rotary_dim/2)`.
|
||||||
/// in the current step.
|
pub fn plain_cos_sin(
|
||||||
///
|
&self,
|
||||||
/// When `rotary_dim < head_dim` the rotation is applied only to the
|
pos: usize,
|
||||||
/// first `rotary_dim` dims of each head; the tail passes through
|
seq_len: usize,
|
||||||
/// unchanged (matches the reference Python's
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
/// `apply_rotary_pos_emb` with non-trivial `partial_rotary_factor`).
|
let cos = self.cos.narrow(0, pos, seq_len)?;
|
||||||
pub fn apply(
|
let sin = self.sin.narrow(0, pos, seq_len)?;
|
||||||
|
Ok((cos, sin))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// cos/sin from explicit per-token 3D position ids, shape
|
||||||
|
/// `(3, seq_len)` (axes: text, height, width). Builds each axis's
|
||||||
|
/// frequencies and blends them at the interleave index sets, so
|
||||||
|
/// every rotary frequency slot is driven by exactly one axis.
|
||||||
|
/// Reduces exactly to [`Self::plain_cos_sin`] when the three axes are
|
||||||
|
/// equal. Returns cos/sin of shape `(seq_len, rotary_dim/2)`.
|
||||||
|
pub fn mrope_cos_sin(&self, position_ids: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let pos = position_ids.to_dtype(DType::F32)?;
|
||||||
|
let (axes, seq_len) = pos.dims2()?;
|
||||||
|
debug_assert_eq!(axes, 3, "mrope position_ids must have 3 axes");
|
||||||
|
// Per-axis freqs: pos[a] (seq,1) @ inv_freq (1,half) → (seq,half).
|
||||||
|
let ft = pos.i(0)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||||
|
let fh = pos.i(1)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||||
|
let fw = pos.i(2)?.reshape((seq_len, 1))?.matmul(&self.inv_freq)?;
|
||||||
|
// Blend: each column belongs to exactly one axis (masks partition
|
||||||
|
// the half-dim), so this picks the right axis per frequency slot.
|
||||||
|
let blended = ft
|
||||||
|
.broadcast_mul(&self.mask_t)?
|
||||||
|
.add(&fh.broadcast_mul(&self.mask_h)?)?
|
||||||
|
.add(&fw.broadcast_mul(&self.mask_w)?)?;
|
||||||
|
let cos = blended.cos()?.to_dtype(self.dtype)?;
|
||||||
|
let sin = blended.sin()?.to_dtype(self.dtype)?;
|
||||||
|
Ok((cos, sin))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply rotary to `q`, `k` (shape `(B, H, L, head_dim)`) using
|
||||||
|
/// precomputed `cos`/`sin` of shape `(L, rotary_dim/2)`. Partial
|
||||||
|
/// rotary: only the first `rotary_dim` dims rotate; the tail passes
|
||||||
|
/// through unchanged.
|
||||||
|
pub fn apply_cos_sin(
|
||||||
&self,
|
&self,
|
||||||
q: &Tensor,
|
q: &Tensor,
|
||||||
k: &Tensor,
|
k: &Tensor,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<(Tensor, Tensor)> {
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
let (_, _, seq_len, head_dim_in) = q.dims4()?;
|
let (_, _, _seq_len, head_dim_in) = q.dims4()?;
|
||||||
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
debug_assert_eq!(head_dim_in, self.head_dim, "q head_dim mismatch");
|
||||||
let cos = self.cos.narrow(0, offset, seq_len)?;
|
|
||||||
let sin = self.sin.narrow(0, offset, seq_len)?;
|
|
||||||
if self.rotary_dim == self.head_dim {
|
if self.rotary_dim == self.head_dim {
|
||||||
// Full rotation.
|
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, cos, sin)?;
|
||||||
let q_embed = candle_nn::rotary_emb::rope_slow(&q.contiguous()?, &cos, &sin)?;
|
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, cos, sin)?;
|
||||||
let k_embed = candle_nn::rotary_emb::rope_slow(&k.contiguous()?, &cos, &sin)?;
|
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
} else {
|
} else {
|
||||||
// Partial rotation: narrow → rotate → cat the untouched tail.
|
// Partial rotation: narrow → rotate → cat the untouched tail.
|
||||||
@@ -102,8 +212,8 @@ impl RotaryEmbedding {
|
|||||||
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
.narrow(candle_core::D::Minus1, 0, self.rotary_dim)?
|
||||||
.contiguous()?;
|
.contiguous()?;
|
||||||
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
let k_pass = k.narrow(candle_core::D::Minus1, self.rotary_dim, tail)?;
|
||||||
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, &cos, &sin)?;
|
let q_rotated = candle_nn::rotary_emb::rope_slow(&q_rot, cos, sin)?;
|
||||||
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, &cos, &sin)?;
|
let k_rotated = candle_nn::rotary_emb::rope_slow(&k_rot, cos, sin)?;
|
||||||
let q_embed =
|
let q_embed =
|
||||||
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
Tensor::cat(&[&q_rotated, &q_pass.contiguous()?], candle_core::D::Minus1)?;
|
||||||
let k_embed =
|
let k_embed =
|
||||||
@@ -112,3 +222,358 @@ impl RotaryEmbedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this paragraph describes the position-id computation performed by
// `compute_mrope_index` (reached through `get_rope_index`), not the
// `mrope_enabled` switch that follows. It is kept as a plain comment so it
// does not become part of that function's rustdoc.
//
// Compute interleaved-M-RoPE 3D position ids for a full prompt that may
// contain image-placeholder runs, plus the decode `rope_delta`.
//
// Mirrors the reference `get_rope_index`:
// - text tokens advance a single running counter `c`, all three axes
//   equal (`[c, c, c]`);
// - each contiguous run of `image_token_id` is one image; its tokens get
//   `[base + t, base + h, base + w]` in row-major (t outer, h, w inner),
//   where `base` is the counter at the run's start; after the run the
//   counter resumes from `base + max(grid_t, grid_h, grid_w)`.
//
// Returns `(text_pos, height_pos, width_pos, rope_delta)`, each pos `Vec`
// of length `input_ids.len()`. `rope_delta = final_counter - seq_len`: add
// it to a plain decode offset so text resumes from the counter after the
// (position-compressed) image blocks.
|
||||||
|
/// Reports whether interleaved M-RoPE positions are used for image tokens.
///
/// Enabled by default. The `NEURON_MROPE` environment variable acts as a
/// kill switch: after trimming and ASCII-lowercasing, the values `0`,
/// `false`, `no` and `off` disable it. An unset, unreadable, or any other
/// value leaves it enabled. When disabled, `get_rope_index` hands back
/// plain sequential positions instead.
pub(crate) fn mrope_enabled() -> bool {
    const OFF_VALUES: [&str; 4] = ["0", "false", "no", "off"];

    match std::env::var("NEURON_MROPE") {
        Ok(raw) => {
            let normalised = raw.trim().to_ascii_lowercase();
            !OFF_VALUES.contains(&normalised.as_str())
        }
        // Unset (or not valid Unicode) → keep the default: on.
        Err(_) => true,
    }
}
|
||||||
|
|
||||||
|
/// Position ids for the forward path. Gated by [`mrope_enabled`]: when
|
||||||
|
/// off, returns plain sequential identity positions on all three axes
|
||||||
|
/// (`mrope_cos_sin` then reduces exactly to plain RoPE), restoring the
|
||||||
|
/// pre-M-RoPE behaviour without touching the rest of the forward.
|
||||||
|
pub(crate) fn get_rope_index(
|
||||||
|
input_ids: &[u32],
|
||||||
|
image_token_id: u32,
|
||||||
|
grids: &[(usize, usize)],
|
||||||
|
) -> Result<MRopeIndex> {
|
||||||
|
if !mrope_enabled() {
|
||||||
|
let seq: Vec<i64> = (0..input_ids.len() as i64).collect();
|
||||||
|
return Ok((seq.clone(), seq.clone(), seq, 0));
|
||||||
|
}
|
||||||
|
compute_mrope_index(input_ids, image_token_id, grids)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The real interleaved-M-RoPE position-id computation (always active in
|
||||||
|
/// unit tests; gated behind [`get_rope_index`] at runtime).
|
||||||
|
///
|
||||||
|
/// `grids` carries the post-merge LM grid `(lm_gh, lm_gw)` for each image
|
||||||
|
/// run, in prompt order — a run length alone cannot recover its
|
||||||
|
/// factorisation, so the grids must be passed (#14 dynamic resolution).
|
||||||
|
/// Each image is a still frame (`grid_t = 1`); its tokens get
|
||||||
|
/// `[base, base + hh, base + ww]` row-major and the shared counter
|
||||||
|
/// resumes at `base + max(lm_gh, lm_gw)`. Multi-image is correct because
|
||||||
|
/// the counter threads across images and interleaved text.
|
||||||
|
pub(crate) fn compute_mrope_index(
|
||||||
|
input_ids: &[u32],
|
||||||
|
image_token_id: u32,
|
||||||
|
grids: &[(usize, usize)],
|
||||||
|
) -> Result<MRopeIndex> {
|
||||||
|
let n = input_ids.len();
|
||||||
|
let mut text = Vec::with_capacity(n);
|
||||||
|
let mut height = Vec::with_capacity(n);
|
||||||
|
let mut width = Vec::with_capacity(n);
|
||||||
|
let mut counter: i64 = 0;
|
||||||
|
let mut i = 0;
|
||||||
|
let mut k = 0; // index into `grids`, one per image run
|
||||||
|
while i < n {
|
||||||
|
if input_ids[i] == image_token_id {
|
||||||
|
let start = i;
|
||||||
|
while i < n && input_ids[i] == image_token_id {
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
let run = i - start;
|
||||||
|
let (grid_h, grid_w) = *grids.get(k).ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"get_rope_index: image run #{k} (len {run}) has no matching grid \
|
||||||
|
({} grids supplied)",
|
||||||
|
grids.len()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
k += 1;
|
||||||
|
if grid_h * grid_w != run {
|
||||||
|
anyhow::bail!(
|
||||||
|
"get_rope_index: image run #{} length {run} != grid {grid_h}×{grid_w} = {}",
|
||||||
|
k - 1,
|
||||||
|
grid_h * grid_w
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let base = counter;
|
||||||
|
for hh in 0..grid_h {
|
||||||
|
for ww in 0..grid_w {
|
||||||
|
text.push(base); // grid_t = 1 → temporal axis const
|
||||||
|
height.push(base + hh as i64);
|
||||||
|
width.push(base + ww as i64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
counter = base + grid_h.max(grid_w) as i64;
|
||||||
|
} else {
|
||||||
|
text.push(counter);
|
||||||
|
height.push(counter);
|
||||||
|
width.push(counter);
|
||||||
|
counter += 1;
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if k != grids.len() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"get_rope_index: prompt has {k} image run(s) but {} grid(s) were supplied",
|
||||||
|
grids.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let delta = counter - n as i64;
|
||||||
|
Ok((text, height, width, delta))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `(text_pos, height_pos, width_pos, rope_delta)` returned by
|
||||||
|
/// [`get_rope_index`]; the three vectors combine into the `(3, seq)`
|
||||||
|
/// MRoPE position-id tensor.
|
||||||
|
///
/// `rope_delta` is the final position counter minus the prompt length as
/// computed by [`compute_mrope_index`]; it is `0` when [`get_rope_index`]
/// takes its identity fallback.
pub(crate) type MRopeIndex = (Vec<i64>, Vec<i64>, Vec<i64>, i64);
|
||||||
|
|
||||||
|
/// Build the `(3, seq)` position-id tensor consumed by
|
||||||
|
/// [`RotaryEmbedding::mrope_cos_sin`] from the three axis vectors.
|
||||||
|
///
|
||||||
|
/// Built directly as **f32** (positions are small integers, exact in
|
||||||
|
/// f32 well past any context length): the freqs matmul needs float
|
||||||
|
/// anyway, and this avoids an i64 tensor / i64→f32 cast on the GPU.
|
||||||
|
pub(crate) fn mrope_position_tensor(
|
||||||
|
text: &[i64],
|
||||||
|
height: &[i64],
|
||||||
|
width: &[i64],
|
||||||
|
dev: &Device,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let seq = text.len();
|
||||||
|
let mut flat = Vec::with_capacity(3 * seq);
|
||||||
|
flat.extend(text.iter().map(|&x| x as f32));
|
||||||
|
flat.extend(height.iter().map(|&x| x as f32));
|
||||||
|
flat.extend(width.iter().map(|&x| x as f32));
|
||||||
|
Tensor::from_vec(flat, (3, seq), dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::IndexOp;

    /// A TextConfig stub with Qwen3.6's rope params (head_dim 256,
    /// partial 0.25 → rotary_dim 64 → half 32; section [11,11,10]).
    fn qwen36_cfg() -> TextConfig {
        serde_json::from_value(serde_json::json!({
            "hidden_size": 5120,
            "num_hidden_layers": 1,
            "num_attention_heads": 64,
            "num_key_value_heads": 8,
            "head_dim": 256,
            "intermediate_size": 1,
            "vocab_size": 10,
            "rms_norm_eps": 1e-6,
            "max_position_embeddings": 64,
            "layer_types": ["full_attention"],
            "rope_parameters": {
                "rope_theta": 10000000.0,
                "partial_rotary_factor": 0.25,
                "mrope_section": [11, 11, 10],
                "mrope_interleaved": true
            }
        }))
        .expect("cfg")
    }

    #[test]
    fn mrope_masks_partition_the_half_dim() {
        let (mt, mh, mw) = mrope_masks(32, &[11, 11, 10], true);
        // Each column belongs to exactly one axis.
        for i in 0..32 {
            let s = mt[i] + mh[i] + mw[i];
            assert_eq!(s, 1.0, "column {i} covered {s} times");
        }
        assert_eq!(mt.iter().sum::<f32>(), 11.0);
        assert_eq!(mh.iter().sum::<f32>(), 11.0);
        assert_eq!(mw.iter().sum::<f32>(), 10.0);
        // Interleave: temporal 0,3,…; height 1,4,…; width 2,5,…
        assert_eq!(mt[0], 1.0);
        assert_eq!(mh[1], 1.0);
        assert_eq!(mw[2], 1.0);
        assert_eq!(mt[3], 1.0);
    }

    /// The load-bearing invariant: when all three position axes are
    /// equal (text), `mrope_cos_sin` must reproduce `plain_cos_sin` to
    /// within f32 rounding (checked here at 1e-6) — i.e. M-RoPE is a
    /// no-op for text, so text inference is unchanged.
    #[test]
    fn mrope_reduces_to_plain_for_equal_axes() {
        let dev = Device::Cpu;
        let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();

        // positions 5,6,7 on all three axes.
        let base: Vec<i64> = vec![5, 6, 7];
        let pos =
            Tensor::from_vec([base.clone(), base.clone(), base].concat(), (3, 3), &dev).unwrap();

        let (mc, ms) = rope.mrope_cos_sin(&pos).unwrap();
        let (pc, ps) = rope.plain_cos_sin(5, 3).unwrap();

        let dcos = (mc - pc).unwrap().abs().unwrap().max_all().unwrap();
        let dsin = (ms - ps).unwrap().abs().unwrap().max_all().unwrap();
        assert!(
            dcos.to_scalar::<f32>().unwrap() < 1e-6,
            "cos mismatch {dcos:?}"
        );
        assert!(
            dsin.to_scalar::<f32>().unwrap() < 1e-6,
            "sin mismatch {dsin:?}"
        );
    }

    /// Hand-checked interleave: a width-axis column (index 2) must track
    /// the WIDTH position, while a temporal column (index 0) tracks the
    /// TEXT position, even when the axes differ.
    #[test]
    fn mrope_blends_axes_at_interleave_columns() {
        let dev = Device::Cpu;
        let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
        let half = rope.inv_freq.dim(1).unwrap();
        let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();

        // One token: text=10, height=3, width=7 — all distinct.
        let pos = Tensor::from_vec(vec![10i64, 3, 7], (3, 1), &dev).unwrap();
        let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
        let cos_row: Vec<f32> = cos.i(0).unwrap().to_vec1().unwrap();
        assert_eq!(cos_row.len(), half);

        // Column 0 (temporal) → text pos 10. Column 1 (height) → 3.
        // Column 2 (width) → 7. Column 3 is temporal again → 10.
        assert!((cos_row[0] - (10.0 * inv[0]).cos()).abs() < 1e-5);
        assert!((cos_row[1] - (3.0 * inv[1]).cos()).abs() < 1e-5);
        assert!((cos_row[2] - (7.0 * inv[2]).cos()).abs() < 1e-5);
        assert!((cos_row[3] - (10.0 * inv[3]).cos()).abs() < 1e-5);
    }

    #[test]
    fn get_rope_index_text_only_is_sequential() {
        let (t, h, w, delta) = compute_mrope_index(&[1, 2, 3, 4], 99, &[]).unwrap();
        assert_eq!(t, vec![0, 1, 2, 3]);
        assert_eq!(h, vec![0, 1, 2, 3]);
        assert_eq!(w, vec![0, 1, 2, 3]);
        assert_eq!(delta, 0, "no image → delta 0 → plain decode positions");
    }

    #[test]
    fn get_rope_index_text_image_text() {
        // [text, image(2x2 run of 4), text]. image_token = 99, grid (2,2).
        let ids = [1u32, 99, 99, 99, 99, 2];
        let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 2)]).unwrap();
        // token 0: text → 0. image base=1, grid 2x2:
        //   t all = 1; h = base+row = [1,1,2,2]; w = base+col = [1,2,1,2].
        // resume from base + max(2,2) = 3. trailing text → 3.
        assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
        assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
        assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
        // final counter = 4, seq_len = 6 → delta = -2 (the 4 image tokens
        // advanced the counter by only 2).
        assert_eq!(delta, -2);
        // Decode after the prompt (offset = 6) → text position 6 + (-2) = 4.
        assert_eq!(6 + delta, 4);
    }

    #[test]
    fn get_rope_index_nonsquare_single_image() {
        // text + image(2 rows × 3 cols = 6 tokens). grid (2,3).
        let ids = [1u32, 99, 99, 99, 99, 99, 99];
        let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 3)]).unwrap();
        // base = 1; row-major h = [0,0,0,1,1,1]+1, w = [0,1,2,0,1,2]+1.
        assert_eq!(t, vec![0, 1, 1, 1, 1, 1, 1]);
        assert_eq!(h, vec![0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(w, vec![0, 1, 2, 3, 1, 2, 3]);
        // resume from base + max(2,3) = 4; seq_len 7, counter 4 → delta -3.
        assert_eq!(delta, 4 - 7);
    }

    #[test]
    fn get_rope_index_two_images_different_grids() {
        // img(2x2)=4, text, img(1x3)=3. grids [(2,2),(1,3)].
        let ids = [99, 99, 99, 99, 7, 99, 99, 99];
        let (t, h, w, delta) = compute_mrope_index(&ids, 99, &[(2, 2), (1, 3)]).unwrap();
        // img1 base=0 → t=0, h=[0,0,1,1], w=[0,1,0,1]; resume max(2,2)=2.
        // text at counter 2. img2 base=3 → t=3, h=[3,3,3], w=[3,4,5];
        // resume 3+max(1,3)=6.
        assert_eq!(t, vec![0, 0, 0, 0, 2, 3, 3, 3]);
        assert_eq!(h, vec![0, 0, 1, 1, 2, 3, 3, 3]);
        assert_eq!(w, vec![0, 1, 0, 1, 2, 3, 4, 5]);
        assert_eq!(delta, 6 - 8);
    }

    /// The runtime entry point follows the `NEURON_MROPE` switch. With the
    /// variable unset (default ON) it returns the real interleaved-M-RoPE
    /// positions; if the environment running the tests has the kill switch
    /// set, it must return the identity fallback instead. Asserting
    /// whichever path the environment selects keeps the test from failing
    /// spuriously when `NEURON_MROPE=0` is exported.
    #[test]
    fn get_rope_index_on_by_default() {
        let ids = [1u32, 99, 99, 99, 99, 2];
        let (t, h, w, delta) = get_rope_index(&ids, 99, &[(2, 2)]).unwrap();
        if mrope_enabled() {
            assert_eq!(t, vec![0, 1, 1, 1, 1, 3]);
            assert_eq!(h, vec![0, 1, 1, 2, 2, 3]);
            assert_eq!(w, vec![0, 1, 2, 1, 2, 3]);
        } else {
            let identity: Vec<i64> = (0..ids.len() as i64).collect();
            assert_eq!(t, identity);
            assert_eq!(h, identity);
            assert_eq!(w, identity);
            assert_eq!(delta, 0);
        }
    }

    #[test]
    fn get_rope_index_grid_mismatches_error() {
        // run length != grid product.
        assert!(compute_mrope_index(&[99u32; 6], 99, &[(2, 2)]).is_err());
        // too few grids for the number of image runs.
        assert!(compute_mrope_index(&[99, 99, 7, 99], 99, &[(1, 2)]).is_err());
        // too many grids.
        assert!(compute_mrope_index(&[99, 99], 99, &[(1, 2), (1, 1)]).is_err());
    }

    #[test]
    fn position_tensor_round_trips_through_mrope_cos_sin() {
        // compute_mrope_index → (3,seq) tensor → mrope_cos_sin, and confirm
        // an image token's height column tracks its grid row (not the text
        // counter), i.e. the end-to-end position plumbing is wired right.
        let dev = Device::Cpu;
        let rope = RotaryEmbedding::new(DType::F32, &qwen36_cfg(), &dev).unwrap();
        let ids = [1u32, 99, 99, 99, 99]; // text + 2x2 image
        let (t, h, w, _d) = compute_mrope_index(&ids, 99, &[(2, 2)]).unwrap();
        let pos = mrope_position_tensor(&t, &h, &w, &dev).unwrap();
        assert_eq!(pos.dims(), &[3, 5]);
        let (cos, _sin) = rope.mrope_cos_sin(&pos).unwrap();
        assert_eq!(cos.dims(), &[5, rope.inv_freq.dim(1).unwrap()]);

        let inv: Vec<f32> = rope.inv_freq.i(0).unwrap().to_vec1().unwrap();
        // Last image token (index 4): grid (h=1, w=1) → base 1 → h=2, w=2.
        // Height column (index 1) must track h-position 2, not text.
        let last: Vec<f32> = cos.i(4).unwrap().to_vec1().unwrap();
        assert!((last[1] - (2.0 * inv[1]).cos()).abs() < 1e-5);
    }

    #[test]
    fn get_rope_index_196_is_14x14() {
        let mut ids = vec![1u32]; // one text token
        ids.extend(std::iter::repeat_n(99u32, 196));
        let (t, h, w, _delta) = compute_mrope_index(&ids, 99, &[(14, 14)]).unwrap();
        // image base = 1. Last image token (index 196) is grid (h=13,w=13).
        assert_eq!(*t.last().unwrap(), 1, "grid_t=1 → temporal const at base");
        assert_eq!(h[1], 1, "first image row at base");
        assert_eq!(w[1], 1, "first image col at base");
        assert_eq!(h[196], 1 + 13, "last image row = base + 13");
        assert_eq!(w[196], 1 + 13, "last image col = base + 13");
    }
}
|
||||||
|
|||||||
835
crates/neuron/src/harness/arch/qwen3_5/vision.rs
Normal file
835
crates/neuron/src/harness/arch/qwen3_5/vision.rs
Normal file
@@ -0,0 +1,835 @@
|
|||||||
|
//! Qwen3.6 vision tower.
|
||||||
|
//!
|
||||||
|
//! 27 pre-norm ViT blocks with **LayerNorm** (with biases — not the
|
||||||
|
//! `(1+w)·x` RmsNorm the language model uses), fused QKV attention,
|
||||||
|
//! GELU-tanh MLP. Followed by a `merger` that LayerNorms each
|
||||||
|
//! 1152-dim vision token, spatially 2×2-merges them into 4608-dim
|
||||||
|
//! groups, and projects to the LM's 5120-dim hidden via
|
||||||
|
//! `linear_fc1 → GELU → linear_fc2`.
|
||||||
|
//!
|
||||||
|
//! Architecture spec sourced from beast's cached Qwen3.6-27B
|
||||||
|
//! safetensors header (Stage A0, see
|
||||||
|
//! `doc/vision-qwen3_6-spec.md`). All weight shapes confirmed
|
||||||
|
//! from the live `.safetensors` headers, not inferred.
|
||||||
|
//!
|
||||||
|
//! **Conv3d wrinkle.** The published `patch_embed.proj.weight` is 5D
|
||||||
|
//! `[1152, 3, 2, 16, 16]` — a 3D conv with kernel
|
||||||
|
//! `(t=2, h=16, w=16)`. Candle 0.10 has no Conv3d. For static images
|
||||||
|
//! we get away with a trick: when the temporal patch size is 2 and we
|
||||||
|
//! duplicate the still image along the temporal axis (`T = 2`,
|
||||||
|
//! frame_0 == frame_1), the Conv3d output equals a Conv2d run with
|
||||||
|
//! the *sum* of the two temporal weight slices:
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! output = W_0 · frame_0 + W_1 · frame_1 + bias
|
||||||
|
//! = (W_0 + W_1) · frame + bias (static image)
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! So at load we sum-collapse the temporal axis and use a 4D
|
||||||
|
//! `Conv2d` kernel. Video support would have to do the real Conv3d
|
||||||
|
//! (different frames mean the trick fails) — tracked alongside the
|
||||||
|
//! dynamic-resolution work in issue #14.
|
||||||
|
//!
|
||||||
|
//! Forward signature (Stage A — no LM splice yet):
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! fn forward(&self, image: &Tensor) -> Result<Tensor>
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! `image` is `(3, H, W)` f32, normalised by `preprocess::preprocess`.
|
||||||
|
//! Returns `(N_lm_tokens, out_hidden_size)` post-merger tokens ready
|
||||||
|
//! to splice into the LM's input embeddings at `<|image_pad|>`
|
||||||
|
//! positions. For Qwen3.6 at 448×448 → 28×28 patches → 14×14 = 196
|
||||||
|
//! LM tokens of dim 5120.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use candle_core::{D, DType, Device, IndexOp, Module, Tensor};
|
||||||
|
use candle_nn::var_builder::ShardedVarBuilder;
|
||||||
|
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
/// Reports whether the environment variable `name` holds a truthy value.
///
/// The value is trimmed and ASCII-lower-cased before comparison; `1`,
/// `true`, `yes` and `on` count as truthy. Any other value, or a variable
/// that `std::env::var` cannot return, yields `false`.
fn env_truthy(name: &str) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return false;
    };
    let normalised = raw.trim().to_ascii_lowercase();
    ["1", "true", "yes", "on"].contains(&normalised.as_str())
}
|
||||||
|
|
||||||
|
/// Legacy escape hatch: when set, use the original Stage-A sequential
|
||||||
|
/// `pos_embed` lookup instead of the bilinear grid interpolation.
|
||||||
|
/// Default off (interpolation on) — for A/B comparison only.
|
||||||
|
fn vision_legacy_pos() -> bool {
|
||||||
|
env_truthy("NEURON_VISION_LEGACY_POS")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Legacy escape hatch: when set, skip the 2D vision rotary in the ViT
|
||||||
|
/// attention (the original Stage-A behaviour). Default off (rotary on)
|
||||||
|
/// — for A/B comparison only.
|
||||||
|
fn vision_legacy_rope() -> bool {
|
||||||
|
env_truthy("NEURON_VISION_LEGACY_ROPE")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Qwen3.6 vision tower hyperparameters. Mirrors the `vision_config`
|
||||||
|
/// block of `config.json`. Only the fields we actually need are
|
||||||
|
/// captured; serde tolerates the rest.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct VisionConfig {
|
||||||
|
/// Number of ViT blocks (`depth: 27` for Qwen3.6).
|
||||||
|
pub depth: usize,
|
||||||
|
/// Vision-token dimension throughout the tower (1152 for Qwen3.6).
|
||||||
|
pub hidden_size: usize,
|
||||||
|
/// MLP intermediate dim (4304).
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
/// Attention head count (16). `head_dim = hidden_size / num_heads`.
|
||||||
|
pub num_heads: usize,
|
||||||
|
/// Number of slots in the learned position embedding (2304).
|
||||||
|
/// Caps the maximum image patch count.
|
||||||
|
pub num_position_embeddings: usize,
|
||||||
|
/// Spatial patch edge in pixels (16).
|
||||||
|
pub patch_size: usize,
|
||||||
|
/// Temporal kernel depth in the patch embed (2 for Qwen3.6 — we
|
||||||
|
/// collapse this into a single Conv2d for static-image inference;
|
||||||
|
/// see the module-level Conv3d wrinkle).
|
||||||
|
pub temporal_patch_size: usize,
|
||||||
|
/// Patches grouped per LM token by the merger (2 → 2×2 = 4
|
||||||
|
/// patches per LM token).
|
||||||
|
pub spatial_merge_size: usize,
|
||||||
|
/// Vision input channels (3, RGB).
|
||||||
|
pub in_channels: usize,
|
||||||
|
/// Merger output dim — matches the LM's `hidden_size` (5120 for
|
||||||
|
/// Qwen3.6). The merger projects from vision dim → LM dim.
|
||||||
|
pub out_hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Epsilon passed to every `LayerNorm` built in this module.
const LAYER_NORM_EPS: f64 = 1.0e-6;

/// How many LM tokens the merger produces for each merged group of
/// vision tokens.
const LM_TOKENS_PER_MERGE_GROUP: usize = 1;
|
||||||
|
|
||||||
|
/// One ViT block: pre-LN → attn → residual; pre-LN → MLP → residual.
|
||||||
|
struct VisionBlock {
|
||||||
|
norm1: LayerNorm,
|
||||||
|
qkv: Linear,
|
||||||
|
proj: Linear,
|
||||||
|
norm2: LayerNorm,
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
num_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionBlock {
|
||||||
|
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let head_dim = h / cfg.num_heads;
|
||||||
|
let norm1 = layer_norm(vb.pp("norm1"), h)?;
|
||||||
|
let qkv = linear(vb.pp("attn.qkv"), h, 3 * h)?;
|
||||||
|
let proj = linear(vb.pp("attn.proj"), h, h)?;
|
||||||
|
let norm2 = layer_norm(vb.pp("norm2"), h)?;
|
||||||
|
let fc1 = linear(vb.pp("mlp.linear_fc1"), h, cfg.intermediate_size)?;
|
||||||
|
let fc2 = linear(vb.pp("mlp.linear_fc2"), cfg.intermediate_size, h)?;
|
||||||
|
Ok(Self {
|
||||||
|
norm1,
|
||||||
|
qkv,
|
||||||
|
proj,
|
||||||
|
norm2,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
num_heads: cfg.num_heads,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `x`: `(N, hidden_size)` un-batched. `rotary`: optional
|
||||||
|
/// `(cos, sin)` each `(N, head_dim/2)` — the 2D vision rotary applied
|
||||||
|
/// to q/k. Returns same shape.
|
||||||
|
fn forward(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
|
||||||
|
let attn_in = self.norm1.forward(x)?;
|
||||||
|
let attn_out = self.attention(&attn_in, rotary)?;
|
||||||
|
let x = x.add(&attn_out)?;
|
||||||
|
let mlp_in = self.norm2.forward(&x)?;
|
||||||
|
let mlp_out = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&mlp_in)?)?)?;
|
||||||
|
x.add(&mlp_out).map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multi-head self-attention over the patch sequence. No causal
|
||||||
|
/// mask — every patch attends to every other patch. When `rotary` is
|
||||||
|
/// given, the 2D vision rotary (row/col position) is applied to q, k
|
||||||
|
/// before the scores, matching HF `apply_rotary_pos_emb_vision`
|
||||||
|
/// (`rope_slow` is the same rotate-half form).
|
||||||
|
fn attention(&self, x: &Tensor, rotary: Option<&(Tensor, Tensor)>) -> Result<Tensor> {
|
||||||
|
let (n, hidden) = x.dims2()?;
|
||||||
|
// qkv: (N, 3*hidden). Split into Q, K, V each (N, hidden).
|
||||||
|
let qkv = self.qkv.forward(x)?;
|
||||||
|
let qkv = qkv.reshape((n, 3, self.num_heads, self.head_dim))?;
|
||||||
|
// Transpose to (3, num_heads, N, head_dim) for per-head views.
|
||||||
|
let qkv = qkv.permute((1, 2, 0, 3))?.contiguous()?;
|
||||||
|
let q = qkv.i(0)?;
|
||||||
|
let k = qkv.i(1)?;
|
||||||
|
let v = qkv.i(2)?;
|
||||||
|
// 2D vision rotary on q, k (full head_dim; rotate-half form).
|
||||||
|
let (q, k) = match rotary {
|
||||||
|
Some((cos, sin)) => {
|
||||||
|
let q = candle_nn::rotary_emb::rope_slow(&q.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
|
||||||
|
let k = candle_nn::rotary_emb::rope_slow(&k.unsqueeze(0)?, cos, sin)?.squeeze(0)?;
|
||||||
|
(q, k)
|
||||||
|
}
|
||||||
|
None => (q, k),
|
||||||
|
};
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
// (num_heads, N, head_dim) @ (num_heads, head_dim, N) -> (num_heads, N, N)
|
||||||
|
let scores = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||||
|
let scores = (scores * scale)?;
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
// (num_heads, N, N) @ (num_heads, N, head_dim) -> (num_heads, N, head_dim)
|
||||||
|
let out = probs.matmul(&v)?;
|
||||||
|
// Merge heads back: (N, num_heads, head_dim) -> (N, hidden).
|
||||||
|
let out = out.permute((1, 0, 2))?.contiguous()?.reshape((n, hidden))?;
|
||||||
|
self.proj.forward(&out).map_err(Into::into)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `merger`: LayerNorm per token → spatial 2×2 merge (concat 4
|
||||||
|
/// adjacent tokens into one 4608-dim vector) → fc1 → GELU-tanh →
|
||||||
|
/// fc2. Output dim is the LM's hidden_size.
|
||||||
|
struct VisionMerger {
|
||||||
|
norm: LayerNorm,
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
merge_input_dim: usize,
|
||||||
|
spatial_merge_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionMerger {
|
||||||
|
fn load(cfg: &VisionConfig, vb: &ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let merge = cfg.spatial_merge_size;
|
||||||
|
let merge_input_dim = h * merge * merge;
|
||||||
|
let norm = layer_norm(vb.pp("norm"), h)?;
|
||||||
|
let fc1 = linear(vb.pp("linear_fc1"), merge_input_dim, merge_input_dim)?;
|
||||||
|
let fc2 = linear(vb.pp("linear_fc2"), merge_input_dim, cfg.out_hidden_size)?;
|
||||||
|
Ok(Self {
|
||||||
|
norm,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
merge_input_dim,
|
||||||
|
spatial_merge_size: merge,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `tokens`: `(grid_h, grid_w, hidden_size)`. The merger reshapes
|
||||||
|
/// each `merge×merge` block of adjacent patches into a single
|
||||||
|
/// concatenated vector, then projects.
|
||||||
|
///
|
||||||
|
/// `grid_h` and `grid_w` must both be multiples of
|
||||||
|
/// `spatial_merge_size`. Returns
|
||||||
|
/// `(grid_h/merge × grid_w/merge, out_hidden_size)`.
|
||||||
|
fn forward(&self, tokens: &Tensor) -> Result<Tensor> {
|
||||||
|
let (gh, gw, h) = tokens.dims3()?;
|
||||||
|
let m = self.spatial_merge_size;
|
||||||
|
anyhow::ensure!(
|
||||||
|
gh.is_multiple_of(m) && gw.is_multiple_of(m),
|
||||||
|
"merger expects spatial dims divisible by merge_size={m}; got ({gh}, {gw})"
|
||||||
|
);
|
||||||
|
let tokens = self.norm.forward(tokens)?;
|
||||||
|
// (gh, gw, h) -> (gh/m, m, gw/m, m, h) -> (gh/m, gw/m, m, m, h)
|
||||||
|
// -> flatten last three -> (gh/m, gw/m, m*m*h) -> (N_lm, merge_input_dim)
|
||||||
|
let out_h = gh / m;
|
||||||
|
let out_w = gw / m;
|
||||||
|
let merged = tokens
|
||||||
|
.reshape((out_h, m, out_w, m, h))?
|
||||||
|
.permute((0, 2, 1, 3, 4))?
|
||||||
|
.contiguous()?
|
||||||
|
.reshape((out_h * out_w, self.merge_input_dim))?;
|
||||||
|
let hidden = self.fc2.forward(&gelu_tanh(&self.fc1.forward(&merged)?)?)?;
|
||||||
|
Ok(hidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 2D rotary position embedding for the vision tower. Each patch's
|
||||||
|
/// `head_dim` rotates by its `(row, col)` grid coordinates: the first
|
||||||
|
/// half of the rotary freqs are driven by the row position, the second
|
||||||
|
/// half by the column. Mirrors HF `Qwen3VLVisionRotaryEmbedding` +
|
||||||
|
/// `rot_pos_emb` (θ = 10000, `dim = head_dim/2`).
|
||||||
|
struct VisionRotaryEmbedding {
|
||||||
|
/// `(half,)` f32, `half = head_dim/4` freqs per spatial axis.
|
||||||
|
inv_freq: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionRotaryEmbedding {
|
||||||
|
fn new(head_dim: usize) -> Self {
|
||||||
|
// HF: Qwen3VLVisionRotaryEmbedding(head_dim // 2), theta 10000.
|
||||||
|
let dim = head_dim / 2;
|
||||||
|
let theta = 10000f32;
|
||||||
|
let inv_freq = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
Self { inv_freq }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// cos/sin for a `gh×gw` patch grid in **row-major** order. Returns
|
||||||
|
/// `(cos, sin)` each `(gh*gw, head_dim/2)`: per patch, the row-axis
|
||||||
|
/// freqs `row·inv_freq` followed by the col-axis freqs `col·inv_freq`
|
||||||
|
/// (then `rope_slow` duplicates them across the full head_dim).
|
||||||
|
fn cos_sin(
|
||||||
|
&self,
|
||||||
|
gh: usize,
|
||||||
|
gw: usize,
|
||||||
|
dev: &Device,
|
||||||
|
dtype: DType,
|
||||||
|
) -> candle_core::Result<(Tensor, Tensor)> {
|
||||||
|
let half = self.inv_freq.len();
|
||||||
|
let n = gh * gw;
|
||||||
|
let mut data = Vec::with_capacity(n * 2 * half);
|
||||||
|
for hi in 0..gh {
|
||||||
|
for wi in 0..gw {
|
||||||
|
for &f in &self.inv_freq {
|
||||||
|
data.push(hi as f32 * f);
|
||||||
|
}
|
||||||
|
for &f in &self.inv_freq {
|
||||||
|
data.push(wi as f32 * f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let freqs = Tensor::from_vec(data, (n, 2 * half), dev)?;
|
||||||
|
let cos = freqs.cos()?.to_dtype(dtype)?;
|
||||||
|
let sin = freqs.sin()?.to_dtype(dtype)?;
|
||||||
|
Ok((cos, sin))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The vision tower itself.
|
||||||
|
pub struct VisionTower {
|
||||||
|
/// Sum-collapsed temporal kernel (Conv2d, see module doc).
|
||||||
|
patch_embed: Conv2d,
|
||||||
|
pos_embed: Embedding,
|
||||||
|
rotary: VisionRotaryEmbedding,
|
||||||
|
blocks: Vec<VisionBlock>,
|
||||||
|
merger: VisionMerger,
|
||||||
|
config: VisionConfig,
|
||||||
|
dtype: DType,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionTower {
|
||||||
|
/// Load from a `ShardedVarBuilder` rooted at the safetensors
|
||||||
|
/// `model.visual.` prefix. Caller is responsible for the `pp` —
|
||||||
|
/// see `Qwen3_5ForCausalLM::new` (Stage A4).
|
||||||
|
pub fn load(cfg: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let device = vb.device().clone();
|
||||||
|
|
||||||
|
// patch_embed.proj is published as 5D Conv3d weight; we
|
||||||
|
// sum-collapse the temporal axis (size = temporal_patch_size)
|
||||||
|
// to get a 4D Conv2d kernel. This is exact for the static-
|
||||||
|
// image case where T = temporal_patch_size frames are
|
||||||
|
// identical (i.e. the input was duplicated along T).
|
||||||
|
let raw_weight = vb
|
||||||
|
.pp("patch_embed.proj")
|
||||||
|
.get(
|
||||||
|
(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.in_channels,
|
||||||
|
cfg.temporal_patch_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
),
|
||||||
|
"weight",
|
||||||
|
)
|
||||||
|
.context("load model.visual.patch_embed.proj.weight (5D Conv3d kernel)")?;
|
||||||
|
// Sum along the temporal axis (dim 2) — see module doc-comment.
|
||||||
|
let folded = raw_weight.sum(2)?; // -> (hidden, in_channels, patch, patch)
|
||||||
|
let proj_bias = vb
|
||||||
|
.pp("patch_embed.proj")
|
||||||
|
.get(cfg.hidden_size, "bias")
|
||||||
|
.context("load model.visual.patch_embed.proj.bias")?;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
stride: cfg.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let patch_embed = Conv2d::new(folded, Some(proj_bias), conv_cfg);
|
||||||
|
|
||||||
|
let pos_embed_weight = vb
|
||||||
|
.pp("pos_embed")
|
||||||
|
.get((cfg.num_position_embeddings, cfg.hidden_size), "weight")
|
||||||
|
.context("load model.visual.pos_embed.weight")?;
|
||||||
|
let pos_embed = Embedding::new(pos_embed_weight, cfg.hidden_size);
|
||||||
|
let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
|
||||||
|
|
||||||
|
let blocks_vb = vb.pp("blocks");
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||||
|
for i in 0..cfg.depth {
|
||||||
|
blocks.push(
|
||||||
|
VisionBlock::load(&cfg, &blocks_vb.pp(i))
|
||||||
|
.with_context(|| format!("load vision block {i}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let merger = VisionMerger::load(&cfg, &vb.pp("merger")).context("load vision merger")?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
patch_embed,
|
||||||
|
pos_embed,
|
||||||
|
rotary,
|
||||||
|
blocks,
|
||||||
|
merger,
|
||||||
|
config: cfg,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Borrows the configuration the tower was loaded with.
pub fn config(&self) -> &VisionConfig {
    let cfg = &self.config;
    cfg
}
|
||||||
|
|
||||||
|
/// Number of LM tokens this tower emits for an `(H, W)` pixel
|
||||||
|
/// image after the merger. Equal to
|
||||||
|
/// `(H / patch_size / spatial_merge_size) * (W / patch_size / spatial_merge_size)`.
|
||||||
|
pub fn lm_tokens_for(&self, h: u32, w: u32) -> usize {
|
||||||
|
let m = self.config.spatial_merge_size;
|
||||||
|
let patch = self.config.patch_size;
|
||||||
|
let gh = (h as usize) / patch / m;
|
||||||
|
let gw = (w as usize) / patch / m;
|
||||||
|
gh * gw * LM_TOKENS_PER_MERGE_GROUP
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bilinearly interpolate the learned `pos_embed` grid (a
|
||||||
|
/// `num_grid_per_side × num_grid_per_side` table, 48×48 for Qwen3.6)
|
||||||
|
/// onto the actual `gh × gw` patch grid, in **row-major** patch
|
||||||
|
/// order. Port of the HF `fast_pos_embed_interpolate`: for each patch
|
||||||
|
/// at fractional grid coord `(linspace(0, ngrid-1, gh)[hi],
|
||||||
|
/// linspace(0, ngrid-1, gw)[wi])`, blend the 4 surrounding grid
|
||||||
|
/// entries by bilinear weights. Returns `(gh*gw, hidden)` in
|
||||||
|
/// `self.dtype`.
|
||||||
|
fn interpolated_pos_embed(&self, gh: usize, gw: usize) -> Result<Tensor> {
|
||||||
|
let ngrid = (self.config.num_position_embeddings as f64).sqrt().round() as usize;
|
||||||
|
anyhow::ensure!(
|
||||||
|
ngrid * ngrid == self.config.num_position_embeddings,
|
||||||
|
"num_position_embeddings {} is not a perfect square",
|
||||||
|
self.config.num_position_embeddings
|
||||||
|
);
|
||||||
|
// Evenly-spaced fractional indices into the [0, ngrid-1] grid.
|
||||||
|
let lin = |n: usize| -> Vec<f64> {
|
||||||
|
if n <= 1 {
|
||||||
|
vec![0.0]
|
||||||
|
} else {
|
||||||
|
let step = (ngrid - 1) as f64 / (n - 1) as f64;
|
||||||
|
(0..n).map(|i| i as f64 * step).collect()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let hs = lin(gh);
|
||||||
|
let ws = lin(gw);
|
||||||
|
let n = gh * gw;
|
||||||
|
|
||||||
|
// Four corner index sets + bilinear weight sets, row-major.
|
||||||
|
let mut idx: [Vec<u32>; 4] = [
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
];
|
||||||
|
let mut wts: [Vec<f32>; 4] = [
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
Vec::with_capacity(n),
|
||||||
|
];
|
||||||
|
for &hv in &hs {
|
||||||
|
let hf = hv as usize; // floor (hv >= 0)
|
||||||
|
let hc = (hf + 1).min(ngrid - 1);
|
||||||
|
let dh = (hv - hf as f64) as f32;
|
||||||
|
for &wv in &ws {
|
||||||
|
let wf = wv as usize;
|
||||||
|
let wc = (wf + 1).min(ngrid - 1);
|
||||||
|
let dw = (wv - wf as f64) as f32;
|
||||||
|
idx[0].push((hf * ngrid + wf) as u32);
|
||||||
|
wts[0].push((1.0 - dh) * (1.0 - dw));
|
||||||
|
idx[1].push((hf * ngrid + wc) as u32);
|
||||||
|
wts[1].push((1.0 - dh) * dw);
|
||||||
|
idx[2].push((hc * ngrid + wf) as u32);
|
||||||
|
wts[2].push(dh * (1.0 - dw));
|
||||||
|
idx[3].push((hc * ngrid + wc) as u32);
|
||||||
|
wts[3].push(dh * dw);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut acc: Option<Tensor> = None;
|
||||||
|
for corner in 0..4 {
|
||||||
|
let idx_t = Tensor::from_vec(std::mem::take(&mut idx[corner]), (n,), &self.device)?;
|
||||||
|
let emb = self.pos_embed.forward(&idx_t)?; // (n, hidden), pos_embed dtype
|
||||||
|
let wt = Tensor::from_vec(std::mem::take(&mut wts[corner]), (n, 1), &self.device)?
|
||||||
|
.to_dtype(self.dtype)?;
|
||||||
|
let term = emb.broadcast_mul(&wt)?;
|
||||||
|
acc = Some(match acc {
|
||||||
|
Some(a) => a.add(&term)?,
|
||||||
|
None => term,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(acc.expect("4 corners accumulated"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode one image.
|
||||||
|
///
|
||||||
|
/// `image`: row-major `(3, H, W)` f32 tensor on `self.device`,
|
||||||
|
/// already normalised by `preprocess::preprocess`. Both `H` and
|
||||||
|
/// `W` must be multiples of `patch_size * spatial_merge_size`.
|
||||||
|
///
|
||||||
|
/// Returns `(N_lm, out_hidden_size)` — LM-side image tokens
|
||||||
|
/// ready to splice into the language model's input embeddings.
|
||||||
|
pub fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||||
|
let (c, h, w) = image.dims3()?;
|
||||||
|
anyhow::ensure!(
|
||||||
|
c == self.config.in_channels,
|
||||||
|
"image must have {} channels, got {c}",
|
||||||
|
self.config.in_channels
|
||||||
|
);
|
||||||
|
let patch = self.config.patch_size;
|
||||||
|
anyhow::ensure!(
|
||||||
|
h.is_multiple_of(patch) && w.is_multiple_of(patch),
|
||||||
|
"image dims must be multiples of patch_size={patch}; got ({h}, {w})"
|
||||||
|
);
|
||||||
|
let gh = h / patch;
|
||||||
|
let gw = w / patch;
|
||||||
|
let n_patches = gh * gw;
|
||||||
|
anyhow::ensure!(
|
||||||
|
n_patches <= self.config.num_position_embeddings,
|
||||||
|
"patch count {n_patches} exceeds pos_embed budget {}",
|
||||||
|
self.config.num_position_embeddings
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add batch axis for conv: (1, 3, H, W) → (1, hidden, gh, gw)
|
||||||
|
// → (hidden, gh, gw) → permute to (gh, gw, hidden) → flatten to (N, hidden)
|
||||||
|
let x = image.unsqueeze(0)?.to_dtype(self.dtype)?;
|
||||||
|
let x = self.patch_embed.forward(&x)?;
|
||||||
|
let x = x.squeeze(0)?;
|
||||||
|
let x = x.permute((1, 2, 0))?.contiguous()?;
|
||||||
|
let x = x.reshape((n_patches, self.config.hidden_size))?;
|
||||||
|
|
||||||
|
// Learned absolute position embeddings. The `pos_embed` table is
|
||||||
|
// a `num_position_embeddings = num_grid_per_side²` learned grid
|
||||||
|
// (48×48 for Qwen3.6); for a `gh×gw` patch grid the reference
|
||||||
|
// (`fast_pos_embed_interpolate`) bilinearly interpolates that
|
||||||
|
// grid to `gh×gw`. The legacy path (a naive sequential lookup of
|
||||||
|
// the first `n_patches` rows) mis-maps the grid stride and
|
||||||
|
// scrambles spatial structure — kept only behind
|
||||||
|
// `NEURON_VISION_LEGACY_POS=1` for A/B comparison.
|
||||||
|
let pos = if vision_legacy_pos() {
|
||||||
|
let positions = Tensor::arange(0u32, n_patches as u32, &self.device)?;
|
||||||
|
self.pos_embed.forward(&positions)?
|
||||||
|
} else {
|
||||||
|
self.interpolated_pos_embed(gh, gw)?
|
||||||
|
};
|
||||||
|
let mut x = x.add(&pos)?;
|
||||||
|
|
||||||
|
// 2D vision rotary (row/col per patch), computed once and applied
|
||||||
|
// in every block's attention. Legacy escape hatch skips it.
|
||||||
|
let rotary = if vision_legacy_rope() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.rotary.cos_sin(gh, gw, &self.device, self.dtype)?)
|
||||||
|
};
|
||||||
|
let rotary_ref = rotary.as_ref();
|
||||||
|
|
||||||
|
for (i, block) in self.blocks.iter().enumerate() {
|
||||||
|
x = block
|
||||||
|
.forward(&x, rotary_ref)
|
||||||
|
.with_context(|| format!("vision block {i}"))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// (n_patches, hidden) → (gh, gw, hidden) for the merger.
|
||||||
|
let x = x.reshape((gh, gw, self.config.hidden_size))?;
|
||||||
|
self.merger.forward(&x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manually load a candle_nn LayerNorm from a ShardedVarBuilder.
|
||||||
|
/// candle_nn's `layer_norm` builder takes `crate::VarBuilder`, not
|
||||||
|
/// `ShardedVarBuilder`, so the existing arch modules in this crate
|
||||||
|
/// uniformly do the manual load + struct construction pattern (see
|
||||||
|
/// `full_attn::load_linear_no_bias`). We follow suit here.
|
||||||
|
fn layer_norm(vb: ShardedVarBuilder, size: usize) -> Result<LayerNorm> {
|
||||||
|
let weight = vb
|
||||||
|
.get(size, "weight")
|
||||||
|
.with_context(|| format!("load LayerNorm.weight at '{}'", vb.prefix()))?;
|
||||||
|
let bias = vb
|
||||||
|
.get(size, "bias")
|
||||||
|
.with_context(|| format!("load LayerNorm.bias at '{}'", vb.prefix()))?;
|
||||||
|
Ok(LayerNorm::new(weight, bias, LAYER_NORM_EPS))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Manually load a candle_nn Linear (with bias) from a
|
||||||
|
/// ShardedVarBuilder. Same rationale as `layer_norm` above.
|
||||||
|
fn linear(vb: ShardedVarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
|
||||||
|
let weight = vb
|
||||||
|
.get((out_dim, in_dim), "weight")
|
||||||
|
.with_context(|| format!("load Linear.weight at '{}'", vb.prefix()))?;
|
||||||
|
let bias = vb
|
||||||
|
.get(out_dim, "bias")
|
||||||
|
.with_context(|| format!("load Linear.bias at '{}'", vb.prefix()))?;
|
||||||
|
Ok(Linear::new(weight, Some(bias)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PyTorch's `gelu_pytorch_tanh` approximation — what the Qwen3.6
|
||||||
|
/// vision tower's `hidden_act` specifies. candle's `Tensor::gelu`
|
||||||
|
/// uses the exact erf-based GELU, so we compute the tanh
|
||||||
|
/// approximation explicitly:
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||||
|
/// ```
|
||||||
|
fn gelu_tanh(x: &Tensor) -> Result<Tensor> {
|
||||||
|
// sqrt(2 / pi) = 0.7978845608028654
|
||||||
|
const COEFF: f64 = 0.7978845608028654;
|
||||||
|
const KAPPA: f64 = 0.044715;
|
||||||
|
let x3 = x.powf(3.0)?;
|
||||||
|
let inner = (x + (x3 * KAPPA)?)?;
|
||||||
|
let inner = (inner * COEFF)?;
|
||||||
|
let t = inner.tanh()?;
|
||||||
|
let one_plus_t = (t + 1.0)?;
|
||||||
|
let out = (x * 0.5)?;
|
||||||
|
let out = out.broadcast_mul(&one_plus_t)?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use candle_core::{DType, Device};
|
||||||
|
|
||||||
|
/// Build a tiny VisionConfig usable on CPU with random weights.
|
||||||
|
/// Match the Qwen3.6 shape relations (depth-N stack, hidden mod
|
||||||
|
/// num_heads, intermediate_size > hidden_size) but with small
|
||||||
|
/// dims so tests run in milliseconds.
|
||||||
|
fn tiny_config() -> VisionConfig {
|
||||||
|
VisionConfig {
|
||||||
|
depth: 2,
|
||||||
|
hidden_size: 32,
|
||||||
|
intermediate_size: 64,
|
||||||
|
num_heads: 4,
|
||||||
|
num_position_embeddings: 64,
|
||||||
|
patch_size: 4,
|
||||||
|
temporal_patch_size: 2,
|
||||||
|
spatial_merge_size: 2,
|
||||||
|
in_channels: 3,
|
||||||
|
out_hidden_size: 48,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hand-construct a VisionTower with random weights. This is the
|
||||||
|
/// same trick `linear_attn::tests::forward_smoke_with_tiny_dimensions`
|
||||||
|
/// uses — bypass the safetensors-backed `ShardedVarBuilder` path
|
||||||
|
/// (which can't be built from in-memory tensors) and assemble the
|
||||||
|
/// struct fields directly. The real `VisionTower::load` is
|
||||||
|
/// exercised by the cuda-integration smoke test in Stage A6.
|
||||||
|
fn tiny_tower(cfg: &VisionConfig) -> VisionTower {
|
||||||
|
let device = Device::Cpu;
|
||||||
|
let dtype = DType::F32;
|
||||||
|
let zeros = |shape: &[usize]| Tensor::zeros(shape, dtype, &device).unwrap();
|
||||||
|
let ones = |shape: &[usize]| Tensor::ones(shape, dtype, &device).unwrap();
|
||||||
|
let randn = |shape: &[usize]| Tensor::randn(0_f32, 0.02, shape, &device).unwrap();
|
||||||
|
|
||||||
|
let patch_embed = Conv2d::new(
|
||||||
|
randn(&[
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.in_channels,
|
||||||
|
cfg.patch_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
]),
|
||||||
|
Some(zeros(&[cfg.hidden_size])),
|
||||||
|
Conv2dConfig {
|
||||||
|
stride: cfg.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let pos_embed = Embedding::new(
|
||||||
|
randn(&[cfg.num_position_embeddings, cfg.hidden_size]),
|
||||||
|
cfg.hidden_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.depth);
|
||||||
|
for _ in 0..cfg.depth {
|
||||||
|
let head_dim = cfg.hidden_size / cfg.num_heads;
|
||||||
|
blocks.push(VisionBlock {
|
||||||
|
norm1: LayerNorm::new(
|
||||||
|
ones(&[cfg.hidden_size]),
|
||||||
|
zeros(&[cfg.hidden_size]),
|
||||||
|
LAYER_NORM_EPS,
|
||||||
|
),
|
||||||
|
qkv: Linear::new(
|
||||||
|
randn(&[3 * cfg.hidden_size, cfg.hidden_size]),
|
||||||
|
Some(zeros(&[3 * cfg.hidden_size])),
|
||||||
|
),
|
||||||
|
proj: Linear::new(
|
||||||
|
randn(&[cfg.hidden_size, cfg.hidden_size]),
|
||||||
|
Some(zeros(&[cfg.hidden_size])),
|
||||||
|
),
|
||||||
|
norm2: LayerNorm::new(
|
||||||
|
ones(&[cfg.hidden_size]),
|
||||||
|
zeros(&[cfg.hidden_size]),
|
||||||
|
LAYER_NORM_EPS,
|
||||||
|
),
|
||||||
|
fc1: Linear::new(
|
||||||
|
randn(&[cfg.intermediate_size, cfg.hidden_size]),
|
||||||
|
Some(zeros(&[cfg.intermediate_size])),
|
||||||
|
),
|
||||||
|
fc2: Linear::new(
|
||||||
|
randn(&[cfg.hidden_size, cfg.intermediate_size]),
|
||||||
|
Some(zeros(&[cfg.hidden_size])),
|
||||||
|
),
|
||||||
|
num_heads: cfg.num_heads,
|
||||||
|
head_dim,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let merge_input_dim = cfg.hidden_size * cfg.spatial_merge_size * cfg.spatial_merge_size;
|
||||||
|
let merger = VisionMerger {
|
||||||
|
norm: LayerNorm::new(
|
||||||
|
ones(&[cfg.hidden_size]),
|
||||||
|
zeros(&[cfg.hidden_size]),
|
||||||
|
LAYER_NORM_EPS,
|
||||||
|
),
|
||||||
|
fc1: Linear::new(
|
||||||
|
randn(&[merge_input_dim, merge_input_dim]),
|
||||||
|
Some(zeros(&[merge_input_dim])),
|
||||||
|
),
|
||||||
|
fc2: Linear::new(
|
||||||
|
randn(&[cfg.out_hidden_size, merge_input_dim]),
|
||||||
|
Some(zeros(&[cfg.out_hidden_size])),
|
||||||
|
),
|
||||||
|
merge_input_dim,
|
||||||
|
spatial_merge_size: cfg.spatial_merge_size,
|
||||||
|
};
|
||||||
|
|
||||||
|
let rotary = VisionRotaryEmbedding::new(cfg.hidden_size / cfg.num_heads);
|
||||||
|
VisionTower {
|
||||||
|
patch_embed,
|
||||||
|
pos_embed,
|
||||||
|
rotary,
|
||||||
|
blocks,
|
||||||
|
merger,
|
||||||
|
config: cfg.clone(),
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn forward_with_random_weights_produces_finite_output() {
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
|
||||||
|
// 16×16 image at patch_size=4 → 4×4 patches → after 2×2
|
||||||
|
// merge → 2×2 = 4 LM tokens of dim out_hidden_size.
|
||||||
|
let image = Tensor::randn(0_f32, 1.0, (3, 16, 16), &Device::Cpu).unwrap();
|
||||||
|
let out = tower.forward(&image).expect("forward");
|
||||||
|
let (n_lm, hidden) = out.dims2().unwrap();
|
||||||
|
assert_eq!(n_lm, 4);
|
||||||
|
assert_eq!(hidden, cfg.out_hidden_size);
|
||||||
|
|
||||||
|
// No NaN/Inf
|
||||||
|
let values: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
|
||||||
|
assert!(
|
||||||
|
values.iter().all(|v| v.is_finite()),
|
||||||
|
"forward must produce finite values"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interpolated_pos_embed_reduces_to_sequential_at_native_grid() {
|
||||||
|
// When the patch grid equals the pos_embed grid (gh=gw=ngrid),
|
||||||
|
// linspace(0,ngrid-1,ngrid) is the integer ladder, so every patch
|
||||||
|
// lands exactly on a grid node (dh=dw=0, corner-0 weight 1) and
|
||||||
|
// the bilinear result is the raw pos_embed rows in row-major
|
||||||
|
// order — i.e. identical to the legacy sequential lookup.
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
let ngrid = (cfg.num_position_embeddings as f64).sqrt() as usize; // 8
|
||||||
|
let interp = tower.interpolated_pos_embed(ngrid, ngrid).unwrap();
|
||||||
|
let seq = tower
|
||||||
|
.pos_embed
|
||||||
|
.forward(&Tensor::arange(0u32, (ngrid * ngrid) as u32, &Device::Cpu).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
let a: Vec<f32> = interp.flatten_all().unwrap().to_vec1().unwrap();
|
||||||
|
let b: Vec<f32> = seq.flatten_all().unwrap().to_vec1().unwrap();
|
||||||
|
assert_eq!(a.len(), b.len());
|
||||||
|
for (x, y) in a.iter().zip(b.iter()) {
|
||||||
|
assert!((x - y).abs() < 1e-5, "interp {x} vs seq {y}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn vision_rotary_row_col_structure() {
|
||||||
|
// head_dim 8 → rotary dim 4 → inv_freq over [0,2] → 2 freqs/axis.
|
||||||
|
let rot = VisionRotaryEmbedding::new(8);
|
||||||
|
assert_eq!(rot.inv_freq.len(), 2);
|
||||||
|
let (cos, sin) = rot.cos_sin(2, 2, &Device::Cpu, DType::F32).unwrap();
|
||||||
|
assert_eq!(cos.dims(), &[4, 4]); // 4 patches, head_dim/2 = 4 cols
|
||||||
|
|
||||||
|
// Patch (0,0): all freqs 0 → cos 1, sin 0.
|
||||||
|
let s0: Vec<f32> = sin.i(0).unwrap().to_vec1().unwrap();
|
||||||
|
assert!(s0.iter().all(|&s| s.abs() < 1e-6));
|
||||||
|
|
||||||
|
// Patch index 2 = grid (1,0): row=1 drives the first half, col=0
|
||||||
|
// leaves the second half at zero.
|
||||||
|
let s2: Vec<f32> = sin.i(2).unwrap().to_vec1().unwrap();
|
||||||
|
assert!(s2[0].abs() > 1e-6, "row half must be non-zero");
|
||||||
|
assert!(
|
||||||
|
s2[2].abs() < 1e-6 && s2[3].abs() < 1e-6,
|
||||||
|
"col half must be zero"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lm_token_count_matches_grid() {
|
||||||
|
let cfg = tiny_config();
|
||||||
|
let tower = tiny_tower(&cfg);
|
||||||
|
// 16x16 image → 4x4 patches → 2x2 = 4 LM tokens
|
||||||
|
assert_eq!(tower.lm_tokens_for(16, 16), 4);
|
||||||
|
// 32x32 image → 8x8 patches → 4x4 = 16 LM tokens
|
||||||
|
assert_eq!(tower.lm_tokens_for(32, 32), 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn rejects_image_with_dims_not_multiple_of_patch() {
    let config = tiny_config();
    let tower = tiny_tower(&config);

    // A 17x17 image is the misaligned input this test expects the
    // tower to refuse, with an error that names `patch_size`.
    let misaligned = Tensor::randn(0_f32, 1.0, (3, 17, 17), &Device::Cpu).unwrap();
    let rendered = format!("{:#}", tower.forward(&misaligned).unwrap_err());
    assert!(rendered.contains("patch_size"));
}
|
||||||
|
|
||||||
|
#[test]
fn rejects_image_with_wrong_channel_count() {
    let config = tiny_config();
    let tower = tiny_tower(&config);

    // Four channels instead of three: the tower must fail with an
    // error that mentions `channels`.
    let four_channel = Tensor::randn(0_f32, 1.0, (4, 16, 16), &Device::Cpu).unwrap();
    let rendered = format!("{:#}", tower.forward(&four_channel).unwrap_err());
    assert!(rendered.contains("channels"));
}
|
||||||
|
|
||||||
|
#[test]
fn gelu_tanh_matches_known_values() {
    // Reference outputs of PyTorch's gelu_pytorch_tanh:
    //   x =  0.0 →  0.0
    //   x =  1.0 →  0.8411920071
    //   x = -1.0 → -0.1588079929
    let input = Tensor::new(&[0.0_f32, 1.0, -1.0], &Device::Cpu).unwrap();
    let output: Vec<f32> = gelu_tanh(&input).unwrap().to_vec1().unwrap();

    assert!(output[0].abs() < 1e-6, "gelu_tanh(0) ≈ 0, got {}", output[0]);

    let err_at_one = (output[1] - 0.841_192_f32).abs();
    assert!(err_at_one < 1e-5, "gelu_tanh(1) ≈ 0.84119, got {}", output[1]);

    // Adding the magnitude is the same as subtracting the negative reference.
    let err_at_minus_one = (output[2] + 0.158_808_f32).abs();
    assert!(
        err_at_minus_one < 1e-5,
        "gelu_tanh(-1) ≈ -0.15881, got {}",
        output[2]
    );
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -43,7 +43,7 @@
|
|||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use cortex_core::openai::{ChatMessage, MessageContent};
|
use cortex_core::openai::{ChatMessage, MessageContent};
|
||||||
use minijinja::Environment;
|
use minijinja::{Environment, Error as MjError, ErrorKind as MjErrorKind, Value as MjValue};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -65,12 +65,55 @@ pub fn chat_templates_enabled() -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience: probe for `tokenizer_config.json` in the same
|
/// Probe for the model's chat template in the same directory the
|
||||||
/// directory the tokenizer was loaded from. Both files come from
|
/// tokenizer was loaded from, following HuggingFace `transformers`
|
||||||
/// the same HuggingFace snapshot in the hf-hub cache, so the
|
/// precedence: a standalone `chat_template.jinja` (then
|
||||||
/// sibling path is reliable.
|
/// `chat_template.json`) wins over the `chat_template` field in
|
||||||
|
/// `tokenizer_config.json`.
|
||||||
|
///
|
||||||
|
/// This matters for multimodal models: Qwen3-VL / Qwen3.6 ship their
|
||||||
|
/// vision-aware template (the one that emits
|
||||||
|
/// `<|vision_start|><|image_pad|><|vision_end|>` per image) **only** in
|
||||||
|
/// `chat_template.jinja`, and may not ship a `tokenizer_config.json` at
|
||||||
|
/// all. Reading `tokenizer_config.json` alone returned `None`, which
|
||||||
|
/// dropped image content into the text-only `format_qwen3_prompt`
|
||||||
|
/// fallback — so image requests rendered zero `<|image_pad|>` tokens
|
||||||
|
/// and the vision path bailed on the count mismatch.
|
||||||
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
|
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
|
||||||
let parent = tokenizer_json_path.parent()?;
|
let parent = tokenizer_json_path.parent()?;
|
||||||
|
|
||||||
|
// 1. Standalone Jinja file — raw template text, highest priority.
|
||||||
|
let jinja_path = parent.join("chat_template.jinja");
|
||||||
|
match std::fs::read_to_string(&jinja_path) {
|
||||||
|
Ok(text) if !text.trim().is_empty() => {
|
||||||
|
tracing::info!(
|
||||||
|
path = %jinja_path.display(),
|
||||||
|
"chat_template: loaded standalone chat_template.jinja"
|
||||||
|
);
|
||||||
|
return Some(text);
|
||||||
|
}
|
||||||
|
Ok(_) => {
|
||||||
|
tracing::warn!(
|
||||||
|
path = %jinja_path.display(),
|
||||||
|
"chat_template: chat_template.jinja present but empty; trying other sources"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(_) => {} // absent — fall through, common case
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Standalone JSON file — `{"chat_template": "..."}` form.
|
||||||
|
let json_path = parent.join("chat_template.json");
|
||||||
|
if json_path.exists()
|
||||||
|
&& let Some(t) = load_chat_template_from(&json_path)
|
||||||
|
{
|
||||||
|
tracing::info!(
|
||||||
|
path = %json_path.display(),
|
||||||
|
"chat_template: loaded standalone chat_template.json"
|
||||||
|
);
|
||||||
|
return Some(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. The `chat_template` field inside tokenizer_config.json.
|
||||||
let config_path = parent.join("tokenizer_config.json");
|
let config_path = parent.join("tokenizer_config.json");
|
||||||
load_chat_template_from(&config_path)
|
load_chat_template_from(&config_path)
|
||||||
}
|
}
|
||||||
@@ -148,6 +191,25 @@ pub fn render_chat_template(
|
|||||||
kwargs: &Value,
|
kwargs: &Value,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let mut env = Environment::new();
|
let mut env = Environment::new();
|
||||||
|
|
||||||
|
// HF chat templates are authored against Python's Jinja2 with its
|
||||||
|
// string semantics. Bridge the two so real model templates render:
|
||||||
|
//
|
||||||
|
// - `pycompat::unknown_method_callback` supplies Python str/list/dict
|
||||||
|
// methods minijinja lacks natively (`startswith`, `endswith`,
|
||||||
|
// `split`, `rstrip`, `lstrip`, …) — the Qwen3.6 template uses
|
||||||
|
// several in its think-block and tool-response handling.
|
||||||
|
// - `raise_exception` is the global HF templates call to reject
|
||||||
|
// malformed inputs (e.g. an image in a system message). Map it to
|
||||||
|
// a render error so the caller falls back / surfaces it.
|
||||||
|
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
|
||||||
|
env.add_function(
|
||||||
|
"raise_exception",
|
||||||
|
|msg: String| -> Result<MjValue, MjError> {
|
||||||
|
Err(MjError::new(MjErrorKind::InvalidOperation, msg))
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
// Compile the template against a fixed name so error messages
|
// Compile the template against a fixed name so error messages
|
||||||
// surface "chat_template" rather than `<template>`.
|
// surface "chat_template" rather than `<template>`.
|
||||||
env.add_template("chat_template", template)
|
env.add_template("chat_template", template)
|
||||||
@@ -210,6 +272,114 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
|
/// Checks that an OpenAI-style `image_url` content part satisfies the
/// image-insertion condition used by the Qwen3.6 vision template once
/// it has passed through `render_chat_template`. In other words:
/// minijinja's `'image_url' in item` must be true for a serde_json
/// object carrying that key, so the template can emit `<|image_pad|>`.
#[test]
fn image_url_part_renders_image_pad() {
    // The condition is the one recorded in doc/vision-qwen3_6-spec.md
    // (lines 8-18 of the real chat_template.jinja). The pieces are
    // concatenated with nothing in between.
    let template = concat!(
        "{%- for message in messages -%}",
        "{%- if message.content is string -%}",
        "{{ message.content }}",
        "{%- else -%}",
        "{%- for item in message.content -%}",
        "{%- if 'image' in item or 'image_url' in item or item.type == 'image' -%}",
        "<|vision_start|><|image_pad|><|vision_end|>",
        "{%- elif item.type == 'text' -%}",
        "{{ item.text }}",
        "{%- endif -%}",
        "{%- endfor -%}",
        "{%- endif -%}",
        "{%- endfor -%}",
    );

    let text_part = json!({"type": "text", "text": "what is this?"});
    let image_part =
        json!({"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}});
    let user_turn = ChatMessage {
        role: "user".into(),
        content: MessageContent::Parts(vec![text_part, image_part]),
        extra: Value::Object(Default::default()),
    };
    let messages = vec![user_turn];

    let out = render_chat_template(template, &messages, &Value::Null, &Value::Null)
        .expect("render should succeed");
    let has_pad = out.contains("<|image_pad|>");
    assert!(
        has_pad,
        "expected the image_url part to emit <|image_pad|>; rendered: {out:?}"
    );
}
|
||||||
|
|
||||||
|
/// When both a standalone `chat_template.jinja` and a
/// `tokenizer_config.json` with a `chat_template` field are present,
/// the `.jinja` file must be the one returned. This is the transformers
/// precedence Qwen3.6 relies on, since its vision template ships only
/// in the `.jinja` file.
#[test]
fn standalone_jinja_template_takes_precedence() {
    // Per-process, per-call-site scratch directory under the system temp dir.
    let dir_name = format!("neuron_ct_precedence_{}_{}", std::process::id(), line!());
    let scratch = std::env::temp_dir().join(dir_name);
    std::fs::create_dir_all(&scratch).unwrap();

    let jinja_file = scratch.join("chat_template.jinja");
    std::fs::write(&jinja_file, "FROM_JINJA").unwrap();
    let config_file = scratch.join("tokenizer_config.json");
    std::fs::write(&config_file, r#"{"chat_template": "FROM_CONFIG"}"#).unwrap();

    // The loader only takes the parent of this path, so tokenizer.json
    // itself never needs to exist.
    let tokenizer_path = scratch.join("tokenizer.json");
    let loaded = load_chat_template_alongside(&tokenizer_path);

    // Best-effort cleanup before asserting, so a failure leaves no litter.
    let _ = std::fs::remove_dir_all(&scratch);

    assert_eq!(loaded.as_deref(), Some("FROM_JINJA"));
}
|
||||||
|
|
||||||
|
/// Without any standalone template file, the loader must return the
/// `chat_template` field of `tokenizer_config.json`, leaving the
/// text-only path unchanged.
#[test]
fn falls_back_to_tokenizer_config_when_no_standalone() {
    // Per-process, per-call-site scratch directory under the system temp dir.
    let dir_name = format!("neuron_ct_fallback_{}_{}", std::process::id(), line!());
    let scratch = std::env::temp_dir().join(dir_name);
    std::fs::create_dir_all(&scratch).unwrap();

    let config_file = scratch.join("tokenizer_config.json");
    std::fs::write(&config_file, r#"{"chat_template": "FROM_CONFIG"}"#).unwrap();

    let tokenizer_path = scratch.join("tokenizer.json");
    let loaded = load_chat_template_alongside(&tokenizer_path);

    // Best-effort cleanup before asserting.
    let _ = std::fs::remove_dir_all(&scratch);

    assert_eq!(loaded.as_deref(), Some("FROM_CONFIG"));
}
|
||||||
|
|
||||||
|
/// Renders the checked-in Qwen3.6-27B `chat_template.jinja` fixture
/// (`testdata/qwen3_6_chat_template.jinja`) for a single text+image
/// user turn and requires exactly one `<|image_pad|>` in the output.
/// Unlike the reduced templates in the other tests, this exercises the
/// whole real template. Any minijinja incompatibility (namespace,
/// macros, reverse slice, string methods) shows up here as a render
/// failure before it reaches production.
#[test]
fn real_qwen3_6_template_renders_one_image_pad() {
    let template = include_str!("testdata/qwen3_6_chat_template.jinja");

    let text_part = json!({"type": "text", "text": "what is this?"});
    let image_part =
        json!({"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA="}});
    let messages = vec![ChatMessage {
        role: "user".into(),
        content: MessageContent::Parts(vec![text_part, image_part]),
        extra: Value::Object(Default::default()),
    }];

    let out = render_chat_template(template, &messages, &Value::Null, &Value::Null)
        .expect("real Qwen3.6 template should render in minijinja");

    let pads = out.matches("<|image_pad|>").count();
    assert_eq!(
        pads, 1,
        "expected exactly one <|image_pad|>; rendered:\n{out}"
    );

    let has_vision_start = out.contains("<|vision_start|>");
    assert!(has_vision_start && out.contains("<|vision_end|>"));
}
|
||||||
|
|
||||||
fn user_msg(text: &str) -> ChatMessage {
|
fn user_msg(text: &str) -> ChatMessage {
|
||||||
ChatMessage {
|
ChatMessage {
|
||||||
role: "user".into(),
|
role: "user".into(),
|
||||||
|
|||||||
@@ -16,10 +16,11 @@
|
|||||||
use crate::harness::candle::ModelArch;
|
use crate::harness::candle::ModelArch;
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
use crate::harness::device_worker::jobs::TpHandle;
|
use crate::harness::device_worker::jobs::TpHandle;
|
||||||
use crate::harness::device_worker::jobs::{ArchHandle, Job};
|
use crate::harness::device_worker::jobs::{ArchHandle, ImageInput, Job};
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
use crate::harness::tp::TpLeaderModel;
|
use crate::harness::tp::TpLeaderModel;
|
||||||
use crate::harness::tp::nccl_state::NcclState;
|
use crate::harness::tp::nccl_state::NcclState;
|
||||||
|
use anyhow::Context as _;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
@@ -158,6 +159,35 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let result = forward_logits(&mut state, handle, &tokens, offset);
|
let result = forward_logits(&mut state, handle, &tokens, offset);
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
|
Job::EncodeImage {
|
||||||
|
handle,
|
||||||
|
pixels,
|
||||||
|
c,
|
||||||
|
h,
|
||||||
|
w,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = encode_image(&mut state, handle, pixels, c, h, w);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
|
Job::ForwardLogitsWithImages {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
images,
|
||||||
|
image_token_id,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = forward_logits_with_images(
|
||||||
|
&mut state,
|
||||||
|
handle,
|
||||||
|
&tokens,
|
||||||
|
offset,
|
||||||
|
images,
|
||||||
|
image_token_id,
|
||||||
|
);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
Job::NcclInit {
|
Job::NcclInit {
|
||||||
cfg,
|
cfg,
|
||||||
comm_id_hex,
|
comm_id_hex,
|
||||||
@@ -171,6 +201,16 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let _ = reply.send(resp);
|
let _ = reply.send(resp);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::GetLeaderComm { reply } => {
|
||||||
|
// Clone the leader's Arc<Comm> out for the async-side
|
||||||
|
// watchdog. `None` before NcclInit. (#17 Stage 2)
|
||||||
|
let comm = state
|
||||||
|
.nccl
|
||||||
|
.comm()
|
||||||
|
.map(crate::harness::tp::nccl_state::SendComm);
|
||||||
|
let _ = reply.send(comm);
|
||||||
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
Job::TpLoadShard {
|
Job::TpLoadShard {
|
||||||
model_id,
|
model_id,
|
||||||
config_json,
|
config_json,
|
||||||
@@ -232,6 +272,27 @@ pub(crate) fn run(device_index: u32, rx: Receiver<Job>, poisoned: Arc<AtomicBool
|
|||||||
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
|
let result = tp_forward_logits(&mut state, handle, &tokens, offset);
|
||||||
let _ = reply.send(result);
|
let _ = reply.send(result);
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpForwardLogitsWithImages {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
chunk_size,
|
||||||
|
reply,
|
||||||
|
} => {
|
||||||
|
let result = tp_forward_logits_with_images(
|
||||||
|
&mut state,
|
||||||
|
handle,
|
||||||
|
&tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
&image_data_uris,
|
||||||
|
chunk_size,
|
||||||
|
);
|
||||||
|
let _ = reply.send(result);
|
||||||
|
}
|
||||||
// Handled by the matches!() check above; reaching here
|
// Handled by the matches!() check above; reaching here
|
||||||
// means a Shutdown slipped past which is a bug.
|
// means a Shutdown slipped past which is a bug.
|
||||||
Job::Shutdown => unreachable!("Shutdown should break above"),
|
Job::Shutdown => unreachable!("Shutdown should break above"),
|
||||||
@@ -704,6 +765,61 @@ fn tp_forward_logits(
|
|||||||
Ok(values)
|
Ok(values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Image-bearing forward for a tensor-parallel model held by this worker.
///
/// Steps, in order:
/// 1. Reject an empty `image_data_uris`.
/// 2. Run every URI through `preprocess_data_uri` with the
///    `PreprocessProfile::qwen3_6()` profile and upload the result to
///    `state.device` as a `(3, h, w)` tensor, where `h`/`w` are the dims
///    the preprocessor returned. Failures are tagged with the image index.
/// 3. Look up the TP model for `handle`.
/// 4. Call `prefill_with_images_chunked` with the tokens, KV offset,
///    pixel tensors, `image_token_id` and `chunk_size`.
/// 5. Squeeze the two leading axes of the result, convert to f32,
///    flatten, and copy to a CPU `Vec<f32>`.
///
/// NOTE(review): the original notes say the preprocessing path is the
/// same deterministic one the subprocess ranks run, so embeddings and
/// grids agree across ranks. That cannot be confirmed from this function.
#[cfg(feature = "cuda")]
fn tp_forward_logits_with_images(
    state: &mut DeviceWorkerState,
    handle: TpHandle,
    tokens: &[u32],
    offset: usize,
    image_token_id: u32,
    image_data_uris: &[String],
    chunk_size: usize,
) -> anyhow::Result<Vec<f32>> {
    use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
    use candle_core::{DType, Tensor};

    anyhow::ensure!(
        !image_data_uris.is_empty(),
        "TpForwardLogitsWithImages dispatched with zero images"
    );

    // Decode + resize each image, then move it onto the worker's device.
    // Collecting into a `Result` stops at the first failing image.
    let profile = PreprocessProfile::qwen3_6();
    let device = &state.device;
    let pixels = image_data_uris
        .iter()
        .enumerate()
        .map(|(idx, uri)| {
            let (px, h, w) = preprocess_data_uri(uri, &profile)
                .with_context(|| format!("preprocess image[{idx}] (TP leader)"))?;
            let tensor = Tensor::from_vec(px, (3, h as usize, w as usize), device)?;
            Ok(tensor)
        })
        .collect::<anyhow::Result<Vec<Tensor>>>()?;

    let Some(model) = state.tp_models.get_mut(&handle) else {
        anyhow::bail!(
            "TpForwardLogitsWithImages: no model for handle {}",
            handle.0
        );
    };

    // Chunked prefill; the model handles the image encode and splice.
    let raw_logits =
        model.prefill_with_images_chunked(tokens, offset, &pixels, image_token_id, chunk_size)?;
    let flat = raw_logits
        .squeeze(0)?
        .squeeze(0)?
        .to_dtype(DType::F32)?
        .flatten_all()?;
    Ok(flat.to_vec1::<f32>()?)
}
|
||||||
|
|
||||||
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
/// Forward step + copy the `[vocab]` logits to a CPU `Vec<f32>` ready
|
||||||
/// for sampling on the async caller. The model's `device()` (CUDA or
|
/// for sampling on the async caller. The model's `device()` (CUDA or
|
||||||
/// CPU) determines where the kernel runs; this fn doesn't care.
|
/// CPU) determines where the kernel runs; this fn doesn't care.
|
||||||
@@ -740,6 +856,119 @@ fn forward_logits(
|
|||||||
Ok(values)
|
Ok(values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run the LM forward with vision-tower image splicing. Stage B3.
|
||||||
|
///
|
||||||
|
/// Encodes each image through the vision tower (`VisionTower::forward`,
|
||||||
|
/// dispatched via `ModelArch::encode_image`), concatenates the
|
||||||
|
/// resulting embeddings into a single `(N_total, hidden)` tensor, and
|
||||||
|
/// passes it to `ModelArch::forward_with_vision` along with the
|
||||||
|
/// prompt-expanded `tokens`. Image embeddings never leave the device.
|
||||||
|
///
|
||||||
|
/// Returns CPU `[vocab]` logits — same shape contract as
|
||||||
|
/// `ForwardLogits` so the async sampler doesn't have to branch on the
|
||||||
|
/// presence of images.
|
||||||
|
fn forward_logits_with_images(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: &[u32],
|
||||||
|
offset: usize,
|
||||||
|
images: Vec<ImageInput>,
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
|
||||||
|
if images.is_empty() {
|
||||||
|
anyhow::bail!("ForwardLogitsWithImages dispatched with zero images");
|
||||||
|
}
|
||||||
|
|
||||||
|
let arch = state.models.get_mut(&handle).ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("ForwardLogitsWithImages: no model for handle {}", handle.0)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// pixel→LM-grid divisor (patch×merge) for this tower; each image's
|
||||||
|
// LM grid is (h/factor, w/factor) (#14 dynamic resolution).
|
||||||
|
let factor = arch.vision_grid_factor().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("ForwardLogitsWithImages: loaded model has no vision tower")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Encode every image on the worker's device, collecting per-image
|
||||||
|
// post-merger embeddings as device-resident tensors plus their LM
|
||||||
|
// grids (for the interleaved-M-RoPE position ids).
|
||||||
|
let mut per_image: Vec<Tensor> = Vec::with_capacity(images.len());
|
||||||
|
let mut grids: Vec<(usize, usize)> = Vec::with_capacity(images.len());
|
||||||
|
for (idx, img) in images.into_iter().enumerate() {
|
||||||
|
anyhow::ensure!(
|
||||||
|
img.pixels.len() == img.c * img.h * img.w,
|
||||||
|
"ForwardLogitsWithImages: image[{idx}] pixels length {} does not match shape ({}, {}, {})",
|
||||||
|
img.pixels.len(),
|
||||||
|
img.c,
|
||||||
|
img.h,
|
||||||
|
img.w,
|
||||||
|
);
|
||||||
|
grids.push((img.h / factor, img.w / factor));
|
||||||
|
let image = Tensor::from_vec(img.pixels, (img.c, img.h, img.w), &state.device)?;
|
||||||
|
let embed = arch
|
||||||
|
.encode_image(&image)
|
||||||
|
.with_context(|| format!("encode image[{idx}]"))?;
|
||||||
|
per_image.push(embed);
|
||||||
|
}
|
||||||
|
// Concatenate per-image embeddings along the patch axis →
|
||||||
|
// (sum_of_patches, hidden). `Tensor::cat` keeps the result
|
||||||
|
// device-resident.
|
||||||
|
let image_embeds = Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)?;
|
||||||
|
|
||||||
|
let input = Tensor::new(tokens, &state.device)?.unsqueeze(0)?;
|
||||||
|
let logits = arch.forward_with_vision(&input, offset, &image_embeds, image_token_id, &grids)?;
|
||||||
|
let values = logits
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.flatten_all()?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the vision tower on a single preprocessed image. Stage A5.
|
||||||
|
///
|
||||||
|
/// `pixels` is a row-major `(c, h, w)` f32 image that the async-side
|
||||||
|
/// `harness::preprocess` produced. We reconstruct the tensor on the
|
||||||
|
/// worker's device (the same device the model was loaded against),
|
||||||
|
/// call `arch.encode_image`, and copy the resulting
|
||||||
|
/// `(N_lm_tokens, hidden_size)` embedding back to CPU f32.
|
||||||
|
///
|
||||||
|
/// Returns the flattened embedding as a `Vec<f32>` — the caller knows
|
||||||
|
/// the LM-side token count from `VisionTower::lm_tokens_for(h, w)`
|
||||||
|
/// and reshapes accordingly. Stage B introduces a device-resident
|
||||||
|
/// embedding-slab variant that avoids this round-trip when the next
|
||||||
|
/// forward call needs the result.
|
||||||
|
fn encode_image(
|
||||||
|
state: &mut DeviceWorkerState,
|
||||||
|
handle: ArchHandle,
|
||||||
|
pixels: Vec<f32>,
|
||||||
|
c: usize,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
) -> anyhow::Result<Vec<f32>> {
|
||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
|
||||||
|
anyhow::ensure!(
|
||||||
|
pixels.len() == c * h * w,
|
||||||
|
"EncodeImage: pixels length {} does not match shape ({c}, {h}, {w})",
|
||||||
|
pixels.len()
|
||||||
|
);
|
||||||
|
let image = Tensor::from_vec(pixels, (c, h, w), &state.device)?;
|
||||||
|
|
||||||
|
let arch = state
|
||||||
|
.models
|
||||||
|
.get(&handle)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("EncodeImage: no model for handle {}", handle.0))?;
|
||||||
|
|
||||||
|
let embed = arch.encode_image(&image)?;
|
||||||
|
let values = embed
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.flatten_all()?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
Ok(values)
|
||||||
|
}
|
||||||
|
|
||||||
/// Reply to a job with the poisoned-worker error. Used when the worker
|
/// Reply to a job with the poisoned-worker error. Used when the worker
|
||||||
/// has flipped into drain-only mode after a CUDA driver error.
|
/// has flipped into drain-only mode after a CUDA driver error.
|
||||||
///
|
///
|
||||||
@@ -773,12 +1002,22 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
|||||||
Job::ForwardLogits { reply, .. } => {
|
Job::ForwardLogits { reply, .. } => {
|
||||||
let _ = reply.send(Err(err()));
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
|
Job::EncodeImage { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
|
Job::ForwardLogitsWithImages { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
Job::NcclInit { reply, .. } => {
|
Job::NcclInit { reply, .. } => {
|
||||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||||
kind: "device_worker_poisoned".into(),
|
kind: "device_worker_poisoned".into(),
|
||||||
message: format!("device worker {device_index} poisoned"),
|
message: format!("device worker {device_index} poisoned"),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::GetLeaderComm { reply } => {
|
||||||
|
let _ = reply.send(None);
|
||||||
|
}
|
||||||
Job::NcclSanity { reply } => {
|
Job::NcclSanity { reply } => {
|
||||||
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
let _ = reply.send(crate::harness::tp::rpc::WorkerResponse::Error {
|
||||||
kind: "device_worker_poisoned".into(),
|
kind: "device_worker_poisoned".into(),
|
||||||
@@ -801,6 +1040,10 @@ fn drain_poisoned(job: Job, device_index: u32) {
|
|||||||
Job::TpForwardLogits { reply, .. } => {
|
Job::TpForwardLogits { reply, .. } => {
|
||||||
let _ = reply.send(Err(err()));
|
let _ = reply.send(Err(err()));
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
Job::TpForwardLogitsWithImages { reply, .. } => {
|
||||||
|
let _ = reply.send(Err(err()));
|
||||||
|
}
|
||||||
Job::Shutdown => {
|
Job::Shutdown => {
|
||||||
// Filtered by the matches!() guard in run(); reaching
|
// Filtered by the matches!() guard in run(); reaching
|
||||||
// here would be a logic error.
|
// here would be a logic error.
|
||||||
|
|||||||
@@ -28,6 +28,29 @@ pub struct ArchHandle(pub u64);
|
|||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub struct TpHandle(pub u64);
|
pub struct TpHandle(pub u64);
|
||||||
|
|
||||||
|
/// One image payload for `Job::ForwardLogitsWithImages` /
|
||||||
|
/// `Job::EncodeImage`. Pixels are row-major `(c, h, w)` f32 — the
|
||||||
|
/// shape `harness::preprocess::preprocess` produces. Carries the
|
||||||
|
/// shape inline since `Vec<f32>` is rank-1.
|
||||||
|
///
|
||||||
|
/// `Clone` so the vision-aware dispatch in `chat_completion` can
|
||||||
|
/// match `&vision_route` (carrying borrowed images) and still hand
|
||||||
|
/// owned `Vec<ImageInput>` to the worker job. The clone cost is one
|
||||||
|
/// pixel-buffer memcpy per image — now variable with dynamic resolution
|
||||||
|
/// (#14): `3 × h × w × 4` bytes, up to 12 MiB at the default 1024²
|
||||||
|
/// `max_pixels` budget.
|
||||||
|
///
|
||||||
|
/// `h`/`w` are the **resized** dims (factor-aligned), so the per-image LM
|
||||||
|
/// grid is `(h/factor, w/factor)` — derived downstream for the splice
|
||||||
|
/// and the interleaved-M-RoPE position ids.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ImageInput {
|
||||||
|
pub pixels: Vec<f32>,
|
||||||
|
pub c: usize,
|
||||||
|
pub h: usize,
|
||||||
|
pub w: usize,
|
||||||
|
}
|
||||||
|
|
||||||
/// One unit of work for the device worker.
|
/// One unit of work for the device worker.
|
||||||
///
|
///
|
||||||
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
/// Phase 1 had only `QueryVram` and `Shutdown`. Phase 2 adds the
|
||||||
@@ -94,6 +117,58 @@ pub enum Job {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
},
|
},
|
||||||
|
/// Run the LM forward with vision splicing in one round-trip.
|
||||||
|
/// Stage B3 of the vision plan.
|
||||||
|
///
|
||||||
|
/// Inputs:
|
||||||
|
/// - `tokens`: prompt-expanded token ids (the caller has already
|
||||||
|
/// replaced each `<|image_pad|>` with N copies per the
|
||||||
|
/// per-image patch count, so `tokens` already contains exactly
|
||||||
|
/// `sum(n_i)` `image_token_id` entries across all images).
|
||||||
|
/// - `offset`: KV-cache position (same contract as `ForwardLogits`).
|
||||||
|
/// - `images`: one entry per image — preprocessed pixels plus the
|
||||||
|
/// `(c, h, w)` shape. Images are encoded on the worker via the
|
||||||
|
/// model's vision tower (`VisionTower::forward`), concatenated
|
||||||
|
/// in order, and spliced into the LM input embeddings at
|
||||||
|
/// `image_token_id` positions.
|
||||||
|
/// - `image_token_id`: the sentinel token (248056 for Qwen3.6).
|
||||||
|
///
|
||||||
|
/// Returns flat CPU `[vocab]` logits, same as `ForwardLogits`.
|
||||||
|
/// Image embeddings stay device-resident — they're never copied
|
||||||
|
/// to CPU. The "tensors don't escape the worker" invariant holds.
|
||||||
|
ForwardLogitsWithImages {
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
images: Vec<ImageInput>,
|
||||||
|
image_token_id: u32,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
|
/// Encode one image through the model's vision tower. Stage A5 of
|
||||||
|
/// the vision plan (`doc/vision-qwen3_6-spec.md`).
|
||||||
|
///
|
||||||
|
/// `pixels` is the CPU-side preprocessed image tensor in row-major
|
||||||
|
/// `(C, H, W)` f32 layout — what `harness::preprocess::preprocess`
|
||||||
|
/// produces. `c`, `h`, `w` carry the shape since `Vec<f32>` itself
|
||||||
|
/// is rank-1. The handler reconstructs the tensor on the worker's
|
||||||
|
/// device, runs `VisionTower::forward`, copies the resulting
|
||||||
|
/// `(N_lm_tokens, hidden_size)` embedding back to CPU as a flat
|
||||||
|
/// `Vec<f32>` (the caller knows the expected shape from
|
||||||
|
/// `VisionTower::lm_tokens_for(h, w) * hidden_size`).
|
||||||
|
///
|
||||||
|
/// Mirrors the `ForwardLogits` "tensors don't escape" invariant —
|
||||||
|
/// device-side image embeddings are dropped at handler return.
|
||||||
|
/// Stage B will introduce a follow-up variant that keeps the
|
||||||
|
/// embeddings device-resident and references them from the next
|
||||||
|
/// `ForwardLogits` call, avoiding the round-trip copy.
|
||||||
|
EncodeImage {
|
||||||
|
handle: ArchHandle,
|
||||||
|
pixels: Vec<f32>,
|
||||||
|
c: usize,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
/// Initialize the leader's NCCL communicator. The worker's
|
/// Initialize the leader's NCCL communicator. The worker's
|
||||||
/// `NcclState` mints the `Comm` here so its underlying
|
/// `NcclState` mints the `Comm` here so its underlying
|
||||||
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
/// `ncclComm_t` and `CudaContext` live on the same thread as
|
||||||
@@ -117,6 +192,17 @@ pub enum Job {
|
|||||||
NcclSanity {
|
NcclSanity {
|
||||||
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
reply: oneshot::Sender<crate::harness::tp::rpc::WorkerResponse>,
|
||||||
},
|
},
|
||||||
|
/// Hand a clonable handle to the leader's NCCL `Comm` back to the
|
||||||
|
/// async side, so the TP step watchdog can call `ncclCommAbort` on
|
||||||
|
/// it from a *different* thread to unblock a wedged collective
|
||||||
|
/// (#17 Stage 2). Fetched once at init while the worker thread is
|
||||||
|
/// still responsive — a thread already wedged in a collective can't
|
||||||
|
/// service this job, which is exactly why the handle is cached
|
||||||
|
/// up front. Replies `None` before `NcclInit` has run.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
GetLeaderComm {
|
||||||
|
reply: oneshot::Sender<Option<crate::harness::tp::nccl_state::SendComm>>,
|
||||||
|
},
|
||||||
/// Load the leader's TP shard on the worker thread. The dispatch
|
/// Load the leader's TP shard on the worker thread. The dispatch
|
||||||
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
/// handler reads `state.nccl.comm()` directly (no cross-thread
|
||||||
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
/// `Arc<Comm>` transfer, no `SendComm` wrapper) and builds the
|
||||||
@@ -161,6 +247,24 @@ pub enum Job {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
reply: oneshot::Sender<Result<Vec<f32>>>,
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
},
|
},
|
||||||
|
/// Image-bearing leader (rank 0) forward for the single-shot vision
|
||||||
|
/// prefill. The handler preprocesses each `image_data_uris` entry
|
||||||
|
/// (the same deterministic path every rank runs), encodes through
|
||||||
|
/// the leader's replicated tower, splices at `image_token_id`, and
|
||||||
|
/// returns CPU-side `[vocab]` logits. Image tensors never escape the
|
||||||
|
/// worker thread. Caller fans out `GenerateStepWithImages` to the
|
||||||
|
/// subprocess ranks and drains them; only the leader forward moves
|
||||||
|
/// here.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
TpForwardLogitsWithImages {
|
||||||
|
handle: TpHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
image_token_id: u32,
|
||||||
|
image_data_uris: Vec<String>,
|
||||||
|
chunk_size: usize,
|
||||||
|
reply: oneshot::Sender<Result<Vec<f32>>>,
|
||||||
|
},
|
||||||
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
/// Tell the worker to break its dispatch loop and exit. Any jobs
|
||||||
/// queued after this in the channel reply `Err` to their oneshot
|
/// queued after this in the channel reply `Err` to their oneshot
|
||||||
/// senders (the senders are dropped on the worker's exit, which
|
/// senders (the senders are dropped on the worker's exit, which
|
||||||
|
|||||||
@@ -161,6 +161,27 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Ask the worker thread for a clonable handle to the leader's NCCL
/// `Comm` (#17 Stage 2).
///
/// The TP step watchdog is meant to fetch this once at init and keep
/// it, so that it can later call `ncclCommAbort` from the async side
/// when a collective wedges the worker thread.
///
/// Returns `None` when any of the following holds:
/// - the worker's `poisoned` flag is set (no job is sent at all),
/// - the job channel rejects the send (worker gone),
/// - the worker drops the reply sender without answering,
/// - the worker answers `None` (documented on `Job::GetLeaderComm`
///   as the reply before `NcclInit` has run).
///
/// Callers treat a missing handle as "can't abort" and log it.
#[cfg(feature = "cuda")]
pub async fn get_leader_comm(&self) -> Option<crate::harness::tp::nccl_state::SendComm> {
    let worker_poisoned = self.poisoned.load(Ordering::Acquire);
    if worker_poisoned {
        return None;
    }

    let (comm_tx, comm_rx) = oneshot::channel();
    let request = Job::GetLeaderComm { reply: comm_tx };
    if self.tx.send(request).is_err() {
        return None;
    }

    // A dropped reply sender and an explicit `None` reply are
    // indistinguishable to the caller by design.
    match comm_rx.await {
        Ok(maybe_comm) => maybe_comm,
        Err(_) => None,
    }
}
|
||||||
|
|
||||||
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
/// Load a GGUF (pre-quantized) single-GPU model on the worker
|
||||||
/// thread. The hf-hub resolution happens on the async caller; the
|
/// thread. The hf-hub resolution happens on the async caller; the
|
||||||
/// resolved local `gguf_path` plus the spec's model_id are sent
|
/// resolved local `gguf_path` plus the spec's model_id are sent
|
||||||
@@ -313,6 +334,90 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Forward with image-aware splicing in one round-trip. Stage B3.
|
||||||
|
///
|
||||||
|
/// Encodes each image on the worker thread (device-resident), then
|
||||||
|
/// runs the LM forward with the embeddings spliced at
|
||||||
|
/// `image_token_id` positions. Returns CPU `[vocab]` logits, same
|
||||||
|
/// shape as `forward_logits`. Image embeddings never copy back to
|
||||||
|
/// CPU.
|
||||||
|
pub async fn forward_logits_with_images(
|
||||||
|
&self,
|
||||||
|
handle: ArchHandle,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
images: Vec<crate::harness::device_worker::jobs::ImageInput>,
|
||||||
|
image_token_id: u32,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::ForwardLogitsWithImages {
|
||||||
|
handle,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
images,
|
||||||
|
image_token_id,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode a preprocessed image through the model's vision tower
|
||||||
|
/// and return the resulting LM-side image embeddings as a
|
||||||
|
/// flattened CPU `Vec<f32>`. Stage A5.
|
||||||
|
///
|
||||||
|
/// `pixels` is the row-major `(c, h, w)` f32 image —
|
||||||
|
/// `harness::preprocess::preprocess` produces this exact shape.
|
||||||
|
/// The caller knows the expected output length from
|
||||||
|
/// `VisionTower::lm_tokens_for(h, w) * hidden_size` and reshapes
|
||||||
|
/// accordingly.
|
||||||
|
pub async fn encode_image(
|
||||||
|
&self,
|
||||||
|
handle: ArchHandle,
|
||||||
|
pixels: Vec<f32>,
|
||||||
|
c: usize,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
) -> Result<Vec<f32>, WorkerError> {
|
||||||
|
if self.poisoned.load(Ordering::Acquire) {
|
||||||
|
return Err(WorkerError::Poisoned {
|
||||||
|
device_index: self.device_index,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
self.tx
|
||||||
|
.send(Job::EncodeImage {
|
||||||
|
handle,
|
||||||
|
pixels,
|
||||||
|
c,
|
||||||
|
h,
|
||||||
|
w,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.map_err(|_| WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
})?;
|
||||||
|
match reply_rx.await {
|
||||||
|
Ok(result) => result.map_err(WorkerError::from),
|
||||||
|
Err(_) => Err(WorkerError::Gone {
|
||||||
|
device_index: self.device_index,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Initialise the leader's NCCL communicator. The reply uses
|
/// Initialise the leader's NCCL communicator. The reply uses
|
||||||
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
/// `WorkerResponse` (same shape subprocess workers use over stdio
|
||||||
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
/// RPC) so `WorkerPool::init_nccl`'s aggregation treats leader +
|
||||||
@@ -488,6 +593,50 @@ impl DeviceWorkerHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Leader-side (rank 0) forward for the single-shot vision prefill
/// under tensor parallelism.
///
/// Sends `Job::TpForwardLogitsWithImages` to the worker thread and
/// awaits the reply. The handler (not shown here) is expected to
/// preprocess each `image_data_uris` entry, encode it, splice at
/// `image_token_id`, run the forward and reply with CPU-side
/// `[vocab]` logits. The `WorkerPool` is responsible for fanning the
/// matching `GenerateStepWithImages` out to the subprocess ranks so
/// the row-parallel collectives can complete; this method only
/// covers the leader.
///
/// # Errors
/// - `WorkerError::Poisoned` if the worker's `poisoned` flag is set;
///   nothing is sent in that case.
/// - `WorkerError::Gone` if the job cannot be sent or the reply
///   sender is dropped before answering.
/// - Any error the handler replies with, converted through
///   `WorkerError::from`.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn tp_forward_logits_with_images(
    &self,
    handle: TpHandle,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
    chunk_size: usize,
) -> Result<Vec<f32>, WorkerError> {
    let device_index = self.device_index;
    if self.poisoned.load(Ordering::Acquire) {
        return Err(WorkerError::Poisoned { device_index });
    }

    let (logits_tx, logits_rx) = oneshot::channel();
    let job = Job::TpForwardLogitsWithImages {
        handle,
        tokens,
        offset,
        image_token_id,
        image_data_uris,
        chunk_size,
        reply: logits_tx,
    };
    if self.tx.send(job).is_err() {
        return Err(WorkerError::Gone { device_index });
    }

    // Outer error: the reply channel closed. Inner error: the job failed.
    let outcome = match logits_rx.await {
        Ok(outcome) => outcome,
        Err(_) => return Err(WorkerError::Gone { device_index }),
    };
    outcome.map_err(WorkerError::from)
}
|
||||||
|
|
||||||
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
/// Send `Job::Shutdown` and join the thread. Idempotent — calling
|
||||||
/// twice is a no-op the second time.
|
/// twice is a no-op the second time.
|
||||||
pub fn shutdown(&self) -> anyhow::Result<()> {
|
pub fn shutdown(&self) -> anyhow::Result<()> {
|
||||||
@@ -569,6 +718,37 @@ mod tests {
|
|||||||
handle.shutdown().expect("shutdown ok");
|
handle.shutdown().expect("shutdown ok");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Stage A5: confirm the EncodeImage job round-trips through the
|
||||||
|
/// worker channel. We don't have a real loaded model in the slab
|
||||||
|
/// here, so the dispatch handler returns the
|
||||||
|
/// "no model for handle" error — which is exactly what we want to
|
||||||
|
/// see, since it proves the message routed through the channel
|
||||||
|
/// and reached the handler. Real-weights validation lives in the
|
||||||
|
/// Stage A7 / Stage B post-deploy smoke on beast.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn encode_image_routes_to_dispatch_and_errors_on_unknown_handle() {
|
||||||
|
use crate::harness::device_worker::jobs::ArchHandle;
|
||||||
|
|
||||||
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
let fake_arch = ArchHandle(99_999);
|
||||||
|
// (3, 4, 4) fake image — minimal payload, gets reconstructed
|
||||||
|
// on the worker before the handler errors out on the unknown
|
||||||
|
// ArchHandle lookup.
|
||||||
|
let pixels = vec![0.0_f32; 3 * 4 * 4];
|
||||||
|
let result = handle.encode_image(fake_arch, pixels, 3, 4, 4).await;
|
||||||
|
match result {
|
||||||
|
Err(WorkerError::Job(e)) => {
|
||||||
|
let msg = format!("{e:#}");
|
||||||
|
assert!(
|
||||||
|
msg.contains("EncodeImage: no model for handle"),
|
||||||
|
"expected unknown-handle error, got: {msg}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
other => panic!("expected Job(Err), got {other:?}"),
|
||||||
|
}
|
||||||
|
handle.shutdown().expect("shutdown ok");
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shutdown_drains_pending_jobs() {
|
async fn shutdown_drains_pending_jobs() {
|
||||||
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
let handle = DeviceWorkerHandle::spawn(0).expect("spawn ok");
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ pub mod candle;
|
|||||||
pub mod chat_template;
|
pub mod chat_template;
|
||||||
pub mod device_worker;
|
pub mod device_worker;
|
||||||
pub mod preflight;
|
pub mod preflight;
|
||||||
|
pub mod preprocess;
|
||||||
pub mod tp;
|
pub mod tp;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
@@ -113,10 +114,8 @@ impl HarnessRegistry {
|
|||||||
for config in configs {
|
for config in configs {
|
||||||
match config.name.as_str() {
|
match config.name.as_str() {
|
||||||
"candle" => {
|
"candle" => {
|
||||||
let harness = Arc::new(candle::CandleHarness::new(
|
let harness =
|
||||||
bind_url.to_string(),
|
candle::CandleHarness::new(bind_url.to_string(), &settings.candle);
|
||||||
settings.candle.hf_cache.clone(),
|
|
||||||
));
|
|
||||||
registry.candle = Some(Arc::clone(&harness));
|
registry.candle = Some(Arc::clone(&harness));
|
||||||
registry.harnesses.insert("candle".into(), harness);
|
registry.harnesses.insert("candle".into(), harness);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
//! cleanly when Phase 1 lands.
|
//! cleanly when Phase 1 lands.
|
||||||
|
|
||||||
use cortex_core::harness::ModelSpec;
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use cortex_core::source::ModelSourceId;
|
||||||
use hf_hub::api::tokio::Api;
|
use hf_hub::api::tokio::Api;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
@@ -115,13 +116,22 @@ pub enum PreflightError {
|
|||||||
/// One network round-trip (`repo.info()`); no blob fetches. Returns
|
/// One network round-trip (`repo.info()`); no blob fetches. Returns
|
||||||
/// `Ok(PlacementPlan)` when the requested combination is feasible, or
|
/// `Ok(PlacementPlan)` when the requested combination is feasible, or
|
||||||
/// a structured `PreflightError` describing what's wrong.
|
/// a structured `PreflightError` describing what's wrong.
|
||||||
pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, PreflightError> {
|
///
|
||||||
let repo = api.model(spec.model_id.clone());
|
/// `api` must already be configured for the scheme `source_id` belongs
|
||||||
|
/// to — caller (typically `CandleHarness::load_model`) builds it via
|
||||||
|
/// `hf_api_for(&source_id.scheme)`. Only the `org/name` portion of the
|
||||||
|
/// id is sent to the registry.
|
||||||
|
pub async fn preflight(
|
||||||
|
api: &Api,
|
||||||
|
source_id: &ModelSourceId,
|
||||||
|
spec: &ModelSpec,
|
||||||
|
) -> Result<PlacementPlan, PreflightError> {
|
||||||
|
let repo = api.model(source_id.repo_path());
|
||||||
let info = repo
|
let info = repo
|
||||||
.info()
|
.info()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| PreflightError::RepoFetchFailed {
|
.map_err(|e| PreflightError::RepoFetchFailed {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
cause: format!("{e}"),
|
cause: format!("{e}"),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -132,13 +142,13 @@ pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, Pre
|
|||||||
match (&format, tp_size, spec.quant.as_deref()) {
|
match (&format, tp_size, spec.quant.as_deref()) {
|
||||||
// No weights at all — nothing to do.
|
// No weights at all — nothing to do.
|
||||||
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
}),
|
}),
|
||||||
|
|
||||||
// GGUF-only + TP: not supported. Today's HauhauCS failure.
|
// GGUF-only + TP: not supported. Today's HauhauCS failure.
|
||||||
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
||||||
Err(PreflightError::TpRequiresSafetensors {
|
Err(PreflightError::TpRequiresSafetensors {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
tp_size: tp,
|
tp_size: tp,
|
||||||
gguf_quants: quants.clone(),
|
gguf_quants: quants.clone(),
|
||||||
suggestion: format!(
|
suggestion: format!(
|
||||||
@@ -154,13 +164,13 @@ pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, Pre
|
|||||||
let picked = pick_gguf_file(&filenames, requested.unwrap_or(""));
|
let picked = pick_gguf_file(&filenames, requested.unwrap_or(""));
|
||||||
match picked {
|
match picked {
|
||||||
Some(fname) => Ok(PlacementPlan {
|
Some(fname) => Ok(PlacementPlan {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
format: format.clone(),
|
format: format.clone(),
|
||||||
tp_size,
|
tp_size,
|
||||||
picked_quant_file: Some(fname),
|
picked_quant_file: Some(fname),
|
||||||
}),
|
}),
|
||||||
None => Err(PreflightError::QuantNotFound {
|
None => Err(PreflightError::QuantNotFound {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
requested: requested.unwrap_or("").to_string(),
|
requested: requested.unwrap_or("").to_string(),
|
||||||
available: quants.clone(),
|
available: quants.clone(),
|
||||||
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
||||||
@@ -174,7 +184,7 @@ pub async fn preflight(api: &Api, spec: &ModelSpec) -> Result<PlacementPlan, Pre
|
|||||||
// on disk, since it needs the parsed JSON.
|
// on disk, since it needs the parsed JSON.
|
||||||
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
||||||
Ok(PlacementPlan {
|
Ok(PlacementPlan {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
format: format.clone(),
|
format: format.clone(),
|
||||||
tp_size,
|
tp_size,
|
||||||
picked_quant_file: None,
|
picked_quant_file: None,
|
||||||
@@ -431,14 +441,20 @@ mod tests {
|
|||||||
format: &SourceFormat,
|
format: &SourceFormat,
|
||||||
filenames: &[&str],
|
filenames: &[&str],
|
||||||
) -> Result<PlacementPlan, PreflightError> {
|
) -> Result<PlacementPlan, PreflightError> {
|
||||||
|
// Tests parse spec.model_id with the default scheme so the
|
||||||
|
// assertions can keep comparing against bare "org/name".
|
||||||
|
let source_id: ModelSourceId = spec
|
||||||
|
.model_id
|
||||||
|
.parse::<ModelSourceId>()
|
||||||
|
.expect("test spec.model_id must parse");
|
||||||
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||||
match (format, tp_size, spec.quant.as_deref()) {
|
match (format, tp_size, spec.quant.as_deref()) {
|
||||||
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
(SourceFormat::Empty, _, _) => Err(PreflightError::EmptyRepo {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
}),
|
}),
|
||||||
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
(SourceFormat::Gguf { quants }, tp, _) if tp > 1 => {
|
||||||
Err(PreflightError::TpRequiresSafetensors {
|
Err(PreflightError::TpRequiresSafetensors {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
tp_size: tp,
|
tp_size: tp,
|
||||||
gguf_quants: quants.clone(),
|
gguf_quants: quants.clone(),
|
||||||
suggestion: format!(
|
suggestion: format!(
|
||||||
@@ -451,13 +467,13 @@ mod tests {
|
|||||||
let picked = pick_gguf_file(filenames, requested.unwrap_or(""));
|
let picked = pick_gguf_file(filenames, requested.unwrap_or(""));
|
||||||
match picked {
|
match picked {
|
||||||
Some(fname) => Ok(PlacementPlan {
|
Some(fname) => Ok(PlacementPlan {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
format: format.clone(),
|
format: format.clone(),
|
||||||
tp_size,
|
tp_size,
|
||||||
picked_quant_file: Some(fname),
|
picked_quant_file: Some(fname),
|
||||||
}),
|
}),
|
||||||
None => Err(PreflightError::QuantNotFound {
|
None => Err(PreflightError::QuantNotFound {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
requested: requested.unwrap_or("").to_string(),
|
requested: requested.unwrap_or("").to_string(),
|
||||||
available: quants.clone(),
|
available: quants.clone(),
|
||||||
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
nearest: nearest_quant(requested.unwrap_or(""), quants),
|
||||||
@@ -466,7 +482,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
(SourceFormat::DenseSafetensors { .. } | SourceFormat::Mixed { .. }, _, _) => {
|
||||||
Ok(PlacementPlan {
|
Ok(PlacementPlan {
|
||||||
model_id: spec.model_id.clone(),
|
model_id: source_id.to_string(),
|
||||||
format: format.clone(),
|
format: format.clone(),
|
||||||
tp_size,
|
tp_size,
|
||||||
picked_quant_file: None,
|
picked_quant_file: None,
|
||||||
|
|||||||
441
crates/neuron/src/harness/preprocess.rs
Normal file
441
crates/neuron/src/harness/preprocess.rs
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
//! Image preprocessing for vision-capable models.
|
||||||
|
//!
|
||||||
|
//! Decodes `data:image/...;base64,...` URIs from OpenAI-style
|
||||||
|
//! `image_url` content parts into the patch tensors a candle vision
|
||||||
|
//! tower expects. Resolution is **dynamic** (#14): each image is
|
||||||
|
//! resized to its native aspect via Qwen `smart_resize` — a
|
||||||
|
//! factor-aligned `(h, w)` whose pixel count lands in the profile's
|
||||||
|
//! `[min_pixels, max_pixels]` budget — so the LM token count varies per
|
||||||
|
//! image (`(h/factor) × (w/factor)`).
|
||||||
|
//!
|
||||||
|
//! Spec reference: `doc/vision-qwen3_6-spec.md` — preprocessor
|
||||||
|
//! section.
|
||||||
|
//!
|
||||||
|
//! Normalisation: pixel value `p ∈ [0, 255]` becomes
|
||||||
|
//! `(p/255 - mean) / std`. Qwen3.6's preprocessor_config.json
|
||||||
|
//! specifies `image_mean = image_std = [0.5, 0.5, 0.5]`, which
|
||||||
|
//! simplifies to `2p/255 - 1` mapping `[0,255]` → `[-1, 1]`. We
|
||||||
|
//! still parameterise mean/std so the same code generalises to other
|
||||||
|
//! VL families (Qwen2-VL uses imagenet stats, for instance).
|
||||||
|
//!
|
||||||
|
//! Pipeline (per image):
|
||||||
|
//! 1. data: URI → base64 decode → bytes
|
||||||
|
//! 2. bytes → image::DynamicImage (PNG/JPEG/WebP/etc)
|
||||||
|
//! 3. smart_resize to a native-aspect, factor-aligned H×W (pixel space)
|
||||||
|
//! 4. RGB→f32, normalise per mean/std
|
||||||
|
//! 5. layout to (C, H, W) tensor
|
||||||
|
//!
|
||||||
|
//! Patchification (cutting the HxW tensor into `patch_size` blocks)
|
||||||
|
//! happens inside the vision tower's `patch_embed` conv, so this
|
||||||
|
//! module stops at "preprocessed RGB f32 tensor."
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use base64::Engine;
|
||||||
|
use image::DynamicImage;
|
||||||
|
use image::imageops::FilterType;
|
||||||
|
|
||||||
|
/// How an image is resized and normalised before it reaches the
/// vision tower.
///
/// Bundles the Qwen `smart_resize` policy (alignment factor plus a
/// pixel budget) with the per-channel normalisation constants taken
/// from the model's `preprocessor_config.json`. Images keep their
/// **native aspect**: the output is a factor-aligned `(h, w)` whose
/// pixel count is steered into `[min_pixels, max_pixels]`, never a
/// fixed square (#14).
#[derive(Debug, Clone)]
pub struct PreprocessProfile {
    /// Alignment unit: both output dimensions are multiples of this.
    /// For Qwen3.6 it is `patch_size(16) × spatial_merge_size(2) = 32`,
    /// which makes the post-merge LM grid exactly `(h/factor, w/factor)`.
    pub factor: u32,
    /// Lower pixel bound. Images smaller than this are upscaled until
    /// they reach it.
    pub min_pixels: u32,
    /// Upper pixel bound. Images larger than this are downscaled to
    /// fit, which caps both the per-image LM token count
    /// (`max_pixels / factor²`) and the O(patches²) ViT attention cost.
    pub max_pixels: u32,
    /// Per-channel (R, G, B) mean subtracted from each pixel after it
    /// has been scaled to `[0, 1]`.
    pub image_mean: [f32; 3],
    /// Per-channel (R, G, B) divisor applied after the mean has been
    /// subtracted.
    pub image_std: [f32; 3],
}
|
||||||
|
|
||||||
|
/// Hard ceiling on `max_pixels` for Qwen3.6.
///
/// The vision tower refuses any image whose **patch** count is larger
/// than its learned pos-embed budget (`num_position_embeddings =
/// 2304 = 48²`; see `vision.rs`). With `patch_size = 16` that budget
/// corresponds to `2304 × 16² = 589_824` source pixels, i.e. at most
/// 2304 patches and at most 576 LM tokens.
///
/// `max_pixels` is clamped to this value so that `smart_resize` cannot
/// emit an over-budget grid. Without the clamp, a per-rank "patch
/// count exceeds pos_embed budget" error in the middle of a TP forward
/// would poison the device context. Because the pos-embed grid is the
/// resolution Qwen3.6 was trained at, the ceiling is principled rather
/// than merely defensive.
const QWEN3_6_MAX_PIXELS_CAP: u32 = 2304 * (16 * 16);

/// Default lower pixel bound for Qwen3.6: `256²` pixels, which is
/// 64 LM tokens.
///
/// Together with [`QWEN3_6_MAX_PIXELS_CAP`] (576 LM tokens) this forms
/// the default budget: generous enough for documents/OCR, bounded
/// enough for serving. Operators can override the bounds with
/// `NEURON_VISION_MIN_PIXELS` / `NEURON_VISION_MAX_PIXELS`; the upper
/// bound stays clamped to the cap, since raising it past the pos-embed
/// budget would poison the model.
const QWEN3_6_MIN_PIXELS: u32 = 256 * 256;
|
||||||
|
|
||||||
|
/// Read a pixel count from the environment variable `name`.
///
/// Surrounding whitespace is ignored. `default` is returned when the
/// variable is unset, is not valid Unicode, or does not parse as a
/// `u32`; a malformed override never aborts startup.
fn env_pixels(name: &str, default: u32) -> u32 {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.trim().parse::<u32>() {
        Ok(pixels) => pixels,
        Err(_) => default,
    }
}
|
||||||
|
|
||||||
|
impl PreprocessProfile {
|
||||||
|
/// Profile for Qwen3.6. Native-aspect `smart_resize` (factor 32),
|
||||||
|
/// normalise to `[-1, 1]` via mean=std=0.5. Pixel budget defaults to
|
||||||
|
/// [`QWEN3_6_MIN_PIXELS`]…[`QWEN3_6_MAX_PIXELS_CAP`], overridable via
|
||||||
|
/// `NEURON_VISION_MIN_PIXELS` / `NEURON_VISION_MAX_PIXELS`. Clamped
|
||||||
|
/// sane: `factor² ≤ min ≤ max`, and `max ≤` the pos-embed cap (so the
|
||||||
|
/// vision tower never rejects a resized image and poisons the context).
|
||||||
|
pub fn qwen3_6() -> Self {
|
||||||
|
let factor = 32u32;
|
||||||
|
let f2 = factor * factor;
|
||||||
|
let min_pixels = env_pixels("NEURON_VISION_MIN_PIXELS", QWEN3_6_MIN_PIXELS)
|
||||||
|
.max(f2)
|
||||||
|
.min(QWEN3_6_MAX_PIXELS_CAP);
|
||||||
|
let max_pixels = env_pixels("NEURON_VISION_MAX_PIXELS", QWEN3_6_MAX_PIXELS_CAP)
|
||||||
|
.min(QWEN3_6_MAX_PIXELS_CAP)
|
||||||
|
.max(min_pixels);
|
||||||
|
Self {
|
||||||
|
factor,
|
||||||
|
min_pixels,
|
||||||
|
max_pixels,
|
||||||
|
image_mean: [0.5, 0.5, 0.5],
|
||||||
|
image_std: [0.5, 0.5, 0.5],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The factor-aligned `(h, w)` this profile would resize a source
|
||||||
|
/// `src_h × src_w` image to. Pure integer policy — no pixel work.
|
||||||
|
pub fn resized_dims(&self, src_h: u32, src_w: u32) -> Result<(u32, u32)> {
|
||||||
|
smart_resize(src_h, src_w, self.factor, self.min_pixels, self.max_pixels)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Qwen `smart_resize`: the smallest `factor`-aligned `(h_bar, w_bar)`
|
||||||
|
/// that preserves aspect ratio as closely as possible while keeping the
|
||||||
|
/// pixel count within `[min_pixels, max_pixels]`. Direct port of the
|
||||||
|
/// canonical Qwen2-VL / Qwen3-VL image-processor function (so neuron's
|
||||||
|
/// grid matches what the model was trained on).
|
||||||
|
///
|
||||||
|
/// Returns `(height, width)`. Errors if the aspect ratio exceeds 200:1
|
||||||
|
/// (degenerate input — a 1-pixel-tall strip), matching upstream.
|
||||||
|
pub fn smart_resize(
|
||||||
|
height: u32,
|
||||||
|
width: u32,
|
||||||
|
factor: u32,
|
||||||
|
min_pixels: u32,
|
||||||
|
max_pixels: u32,
|
||||||
|
) -> Result<(u32, u32)> {
|
||||||
|
let h = height.max(1) as f64;
|
||||||
|
let w = width.max(1) as f64;
|
||||||
|
let ratio = h.max(w) / h.min(w);
|
||||||
|
if ratio > 200.0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"image aspect ratio {ratio:.1}:1 exceeds the 200:1 limit ({height}×{width}); \
|
||||||
|
refusing to resize"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let f = factor as f64;
|
||||||
|
let (minp, maxp) = (min_pixels as f64, max_pixels as f64);
|
||||||
|
// round-to-nearest-factor (may be 0 for sub-factor inputs; the
|
||||||
|
// min-pixels branch below grows it back up).
|
||||||
|
let mut h_bar = (h / f).round() * f;
|
||||||
|
let mut w_bar = (w / f).round() * f;
|
||||||
|
if h_bar * w_bar > maxp {
|
||||||
|
let beta = (h * w / maxp).sqrt();
|
||||||
|
h_bar = f.max((h / beta / f).floor() * f);
|
||||||
|
w_bar = f.max((w / beta / f).floor() * f);
|
||||||
|
} else if h_bar * w_bar < minp {
|
||||||
|
let beta = (minp / (h * w)).sqrt();
|
||||||
|
h_bar = (h * beta / f).ceil() * f;
|
||||||
|
w_bar = (w * beta / f).ceil() * f;
|
||||||
|
}
|
||||||
|
Ok((h_bar as u32, w_bar as u32))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a `data:image/...;base64,...` URI into an in-memory image.
|
||||||
|
///
|
||||||
|
/// Accepts the OpenAI Chat Completions `image_url` shape — a string
|
||||||
|
/// URL with `data:` scheme and base64 payload. The MIME type is read
|
||||||
|
/// from the URI for diagnostics but `image::load_from_memory` sniffs
|
||||||
|
/// the format from the bytes themselves, so the MIME is advisory.
|
||||||
|
///
|
||||||
|
/// Bare `http(s)://` URLs are explicitly rejected here — fetching
|
||||||
|
/// them from a vision-model server is a fingerprintable behaviour
|
||||||
|
/// (server-side request forgery, infinite recursion if the URL
|
||||||
|
/// points at the gateway itself, etc.). Clients that want remote
|
||||||
|
/// images can fetch them and pass base64 themselves.
|
||||||
|
pub fn decode_data_uri(uri: &str) -> Result<DynamicImage> {
|
||||||
|
let after_scheme = uri
|
||||||
|
.strip_prefix("data:")
|
||||||
|
.ok_or_else(|| anyhow!("image_url must use data: scheme; got {uri:.40}…"))?;
|
||||||
|
let (meta, payload) = after_scheme
|
||||||
|
.split_once(',')
|
||||||
|
.ok_or_else(|| anyhow!("malformed data URI: missing ',' separator"))?;
|
||||||
|
if !meta.contains(";base64") {
|
||||||
|
anyhow::bail!(
|
||||||
|
"data URI must use base64 encoding (got '{meta}'); raw URL-encoded payloads not supported"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let bytes = base64::engine::general_purpose::STANDARD
|
||||||
|
.decode(payload.trim())
|
||||||
|
.context("base64-decode image data URI payload")?;
|
||||||
|
image::load_from_memory(&bytes).context("decode image bytes (PNG/JPEG/WebP/etc)")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resize and normalise an image into a `(3, H, W)` row-major
|
||||||
|
/// `Vec<f32>` ready to hand to the vision tower's `patch_embed`
|
||||||
|
/// conv.
|
||||||
|
///
|
||||||
|
/// Uses bilinear resampling — Qwen2-VL's reference uses bicubic, but
|
||||||
|
/// bilinear is what the candle ecosystem standardises on and is
|
||||||
|
/// faster on CPU. Quality difference is marginal for downstream
|
||||||
|
/// vision-encoder consumption. The numerical-validation issue (#15)
|
||||||
|
/// will quantify any discrepancy.
|
||||||
|
pub fn preprocess(img: &DynamicImage, profile: &PreprocessProfile) -> Result<(Vec<f32>, u32, u32)> {
|
||||||
|
let (h_bar, w_bar) = profile.resized_dims(img.height(), img.width())?;
|
||||||
|
let rgb = img
|
||||||
|
.resize_exact(w_bar, h_bar, FilterType::Triangle)
|
||||||
|
.to_rgb8();
|
||||||
|
let h = h_bar as usize;
|
||||||
|
let w = w_bar as usize;
|
||||||
|
let mut out = vec![0.0_f32; 3 * h * w];
|
||||||
|
// Row-major (C, H, W). Candle's Conv2d expects NCHW, so this is
|
||||||
|
// the natural layout — the caller stacks `n` of these along the
|
||||||
|
// batch axis as needed.
|
||||||
|
for c in 0..3 {
|
||||||
|
let mean = profile.image_mean[c];
|
||||||
|
let std = profile.image_std[c];
|
||||||
|
for y in 0..h {
|
||||||
|
for x in 0..w {
|
||||||
|
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||||
|
let raw = pixel[c] as f32 / 255.0;
|
||||||
|
out[c * h * w + y * w + x] = (raw - mean) / std;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((out, h_bar, w_bar))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Combined helper: decode + preprocess in one call. Returns the
|
||||||
|
/// `(3, h, w)` row-major pixels plus the resized `(h, w)` — the caller
|
||||||
|
/// needs the dims to build the tensor and to derive the LM token grid
|
||||||
|
/// `(h/factor, w/factor)`. Most call sites use this; the two-step path
|
||||||
|
/// exists for callers (tests, future video preprocessing) that need the
|
||||||
|
/// intermediate `DynamicImage`.
|
||||||
|
pub fn preprocess_data_uri(uri: &str, profile: &PreprocessProfile) -> Result<(Vec<f32>, u32, u32)> {
|
||||||
|
let img = decode_data_uri(uri)?;
|
||||||
|
preprocess(&img, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resized `(h, w)` for a data-URI image **without** running the pixel
|
||||||
|
/// normalisation — decode header + `smart_resize` only. Lets a caller
|
||||||
|
/// that just needs the LM token count (e.g. the TP leader expanding the
|
||||||
|
/// prompt) avoid materialising the full pixel tensor twice.
|
||||||
|
pub fn resized_dims_for_uri(uri: &str, profile: &PreprocessProfile) -> Result<(u32, u32)> {
|
||||||
|
let img = decode_data_uri(uri)?;
|
||||||
|
profile.resized_dims(img.height(), img.width())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Base64 payload of a 1×1 red PNG, hand-built. Matches the
    /// well-known smallest valid PNG used in tests/curl examples
    /// elsewhere.
    const ONE_BY_ONE_RED_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==";

    /// The fixture PNG wrapped as a `data:image/png;base64,…` URI.
    fn red_png_uri() -> String {
        ["data:image/png;base64,", ONE_BY_ONE_RED_PNG_B64].concat()
    }

    /// Render an error with its alternate (`{:#}`) formatting so the
    /// assertions can match against the whole context chain.
    fn rendered<E: std::fmt::Display>(err: &E) -> String {
        format!("{err:#}")
    }

    // ── decode_data_uri ────────────────────────────────────────────

    #[test]
    fn decodes_well_formed_png_data_uri() {
        let decoded = decode_data_uri(&red_png_uri()).expect("decode 1x1 png");
        assert_eq!(decoded.width(), 1);
        assert_eq!(decoded.height(), 1);
    }

    #[test]
    fn rejects_non_data_scheme() {
        let failure = decode_data_uri("https://example.com/cat.jpg")
            .expect_err("http(s) URLs must be rejected");
        assert!(rendered(&failure).contains("data:"));
    }

    #[test]
    fn rejects_malformed_uri_without_comma() {
        let failure = decode_data_uri("data:image/png;base64").unwrap_err();
        assert!(rendered(&failure).contains("','"));
    }

    #[test]
    fn rejects_non_base64_payload() {
        let failure = decode_data_uri("data:image/png,raw-bytes-here").unwrap_err();
        assert!(rendered(&failure).contains("base64"));
    }

    #[test]
    fn rejects_bad_base64_payload() {
        let failure = decode_data_uri("data:image/png;base64,not!valid!base64!").unwrap_err();
        assert!(rendered(&failure).contains("base64"));
    }

    #[test]
    fn rejects_garbage_image_bytes() {
        // The payload is valid base64 ("Hello World!") but not an image.
        let failure = decode_data_uri("data:image/png;base64,SGVsbG8gV29ybGQh").unwrap_err();
        assert!(
            rendered(&failure).contains("decode image"),
            "should fail at image-decode step"
        );
    }

    // ── preprocess ─────────────────────────────────────────────────

    #[test]
    fn preprocess_red_image_produces_correct_shape_and_values() {
        let profile = PreprocessProfile::qwen3_6();
        // A tiny pure-red image is built directly (no data: URI) so the
        // test isolates the resize + normalise path.
        let red = DynamicImage::ImageRgb8(ImageBuffer::from_pixel(2, 2, Rgb([255u8, 0, 0])));
        let (out, h_bar, w_bar) = preprocess(&red, &profile).expect("preprocess");

        let plane = h_bar as usize * w_bar as usize;
        assert_eq!(out.len(), 3 * plane);
        // Dims are factor-aligned and at least the min-pixel floor.
        assert_eq!(h_bar % profile.factor, 0);
        assert_eq!(w_bar % profile.factor, 0);
        assert!(plane >= profile.min_pixels as usize);
        // With mean = 0.5 and std = 0.5:
        //   red   (255/255 = 1.0) → (1.0 - 0.5) / 0.5 =  1.0
        //   green/blue      (0.0) → (0.0 - 0.5) / 0.5 = -1.0
        assert!(
            (out[0] - 1.0).abs() < 1e-5,
            "R[0] should be 1.0, got {}",
            out[0]
        );
        assert!((out[plane] - (-1.0)).abs() < 1e-5, "G[0] should be -1.0");
        assert!(
            (out[2 * plane] - (-1.0)).abs() < 1e-5,
            "B[0] should be -1.0"
        );
        // Every value must be finite.
        assert!(out.iter().all(|v| v.is_finite()), "no NaN/Inf in output");
    }

    #[test]
    fn preprocess_data_uri_end_to_end() {
        let profile = PreprocessProfile::qwen3_6();
        let uri = red_png_uri();
        let (out, h, w) = preprocess_data_uri(&uri, &profile).expect("e2e preprocess");
        assert_eq!(out.len(), 3 * h as usize * w as usize);
        assert!(out.iter().all(|v| v.is_finite()));
        // The dims-only helper must agree with the full preprocess.
        let dims_only = resized_dims_for_uri(&uri, &profile).expect("dims");
        assert_eq!((h, w), dims_only);
    }

    #[test]
    fn preprocess_grayscale_image_promotes_to_rgb() {
        let profile = PreprocessProfile::qwen3_6();
        // 1x1 grayscale = 200 → after conversion to RGB all three
        // channels equal 200, normalised → (200/255 - 0.5)/0.5 ≈ 0.569
        let luma = ImageBuffer::from_pixel(1, 1, image::Luma([200]));
        let gray = DynamicImage::ImageLuma8(luma);
        let (out, h_bar, w_bar) = preprocess(&gray, &profile).expect("preprocess");
        let expected = ((200.0 / 255.0) - 0.5) / 0.5;
        let plane = h_bar as usize * w_bar as usize;
        for c in 0..3 {
            let v = out[c * plane];
            assert!(
                (v - expected).abs() < 1e-3,
                "channel {c}: expected {expected}, got {v}"
            );
        }
    }

    // ── smart_resize ───────────────────────────────────────────────

    #[test]
    fn smart_resize_keeps_factor_aligned_square_in_budget() {
        // 448×448 sits inside [65536, 1048576] and is factor-aligned →
        // unchanged. (Regression guard for the old fixed-res sweet spot.)
        let dims = smart_resize(448, 448, 32, 65_536, 1_048_576).unwrap();
        assert_eq!(dims, (448, 448));
    }

    #[test]
    fn smart_resize_preserves_aspect_and_caps_at_max() {
        // 3000×4000 (landscape) → downscaled under max_pixels, aspect kept.
        let (h, w) = smart_resize(3000, 4000, 32, 65_536, 1_048_576).unwrap();
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
        let area = (h as u64) * (w as u64);
        assert!(area <= 1_048_576, "must respect max_pixels");
        assert!(w > h, "landscape orientation preserved");
        // aspect ≈ 4000/3000 = 1.333; allow a factor-rounding tolerance.
        let ar = w as f64 / h as f64;
        assert!((ar - 4.0 / 3.0).abs() < 0.15, "aspect ~4:3, got {ar:.3}");
    }

    #[test]
    fn smart_resize_floors_tiny_image_at_min() {
        // 16×16 → upscaled to at least min_pixels, factor-aligned.
        let (h, w) = smart_resize(16, 16, 32, 65_536, 1_048_576).unwrap();
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
        let area = (h as u64) * (w as u64);
        assert!(area >= 65_536, "must respect min_pixels");
    }

    #[test]
    fn smart_resize_tall_nonsquare_stays_nonsquare() {
        // A tall screenshot keeps portrait orientation.
        let (h, w) = smart_resize(2000, 500, 32, 65_536, 1_048_576).unwrap();
        assert!(h > w, "portrait orientation preserved");
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
    }

    #[test]
    fn smart_resize_rejects_extreme_aspect() {
        let failure = smart_resize(1, 500, 32, 65_536, 1_048_576).unwrap_err();
        assert!(rendered(&failure).contains("200:1"));
    }

    // ── qwen3_6 profile budgets ────────────────────────────────────

    #[test]
    fn qwen3_6_never_exceeds_pos_embed_patch_budget() {
        // The pos-embed cap must hold for huge, tall, wide, and extreme
        // images — exceeding 2304 patches errors mid-tower and poisons
        // the device context, so this invariant is load-bearing.
        let p = PreprocessProfile::qwen3_6();
        let sources = [
            (8000u32, 6000u32),
            (808, 1600),
            (4000, 400),
            (1, 199),
            (16, 16),
        ];
        for (sh, sw) in sources {
            let (h, w) = p.resized_dims(sh, sw).unwrap();
            let patches = (h / 16) * (w / 16);
            assert!(
                patches <= 2304,
                "{sh}x{sw} → {h}x{w} = {patches} patches exceeds the 2304 budget"
            );
        }
    }

    #[test]
    fn qwen3_6_default_budget_bounds_lm_tokens() {
        // A huge source image caps at max_pixels → the per-image LM token
        // count stays within budget (so it can't blow NEURON_MAX_PROMPT_TOKENS).
        let p = PreprocessProfile::qwen3_6();
        let budget = p.max_pixels / (p.factor * p.factor);
        let (h, w) = p.resized_dims(8000, 6000).unwrap();
        let lm_tokens = (h / p.factor) * (w / p.factor);
        assert!(
            lm_tokens <= budget,
            "max-res image LM tokens {lm_tokens} must stay within budget {budget}"
        );
    }
}
|
||||||
154
crates/neuron/src/harness/testdata/qwen3_6_chat_template.jinja
vendored
Normal file
154
crates/neuron/src/harness/testdata/qwen3_6_chat_template.jinja
vendored
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
{%- set image_count = namespace(value=0) %}
|
||||||
|
{%- set video_count = namespace(value=0) %}
|
||||||
|
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
|
||||||
|
{%- if content is string %}
|
||||||
|
{{- content }}
|
||||||
|
{%- elif content is iterable and content is not mapping %}
|
||||||
|
{%- for item in content %}
|
||||||
|
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
|
||||||
|
{%- if is_system_content %}
|
||||||
|
{{- raise_exception('System message cannot contain images.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if do_vision_count %}
|
||||||
|
{%- set image_count.value = image_count.value + 1 %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if add_vision_id %}
|
||||||
|
{{- 'Picture ' ~ image_count.value ~ ': ' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
|
||||||
|
{%- elif 'video' in item or item.type == 'video' %}
|
||||||
|
{%- if is_system_content %}
|
||||||
|
{{- raise_exception('System message cannot contain videos.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if do_vision_count %}
|
||||||
|
{%- set video_count.value = video_count.value + 1 %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if add_vision_id %}
|
||||||
|
{{- 'Video ' ~ video_count.value ~ ': ' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
|
||||||
|
{%- elif 'text' in item %}
|
||||||
|
{{- item.text }}
|
||||||
|
{%- else %}
|
||||||
|
{{- raise_exception('Unexpected item type in content.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- elif content is none or content is undefined %}
|
||||||
|
{{- '' }}
|
||||||
|
{%- else %}
|
||||||
|
{{- raise_exception('Unexpected content type.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endmacro %}
|
||||||
|
{%- if not messages %}
|
||||||
|
{{- raise_exception('No messages provided.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if tools and tools is iterable and tools is not mapping %}
|
||||||
|
{{- '<|im_start|>system\n' }}
|
||||||
|
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{{- "\n" }}
|
||||||
|
{{- tool | tojson }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- "\n</tools>" }}
|
||||||
|
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||||
|
{%- if messages[0].role == 'system' %}
|
||||||
|
{%- set content = render_content(messages[0].content, false, true)|trim %}
|
||||||
|
{%- if content %}
|
||||||
|
{{- '\n\n' + content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{%- if messages[0].role == 'system' %}
|
||||||
|
{%- set content = render_content(messages[0].content, false, true)|trim %}
|
||||||
|
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||||
|
{%- for message in messages[::-1] %}
|
||||||
|
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||||
|
{%- if ns.multi_step_tool and message.role == "user" %}
|
||||||
|
{%- set content = render_content(message.content, false)|trim %}
|
||||||
|
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
|
||||||
|
{%- set ns.multi_step_tool = false %}
|
||||||
|
{%- set ns.last_query_index = index %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if ns.multi_step_tool %}
|
||||||
|
{{- raise_exception('No user query found in messages.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- set content = render_content(message.content, true)|trim %}
|
||||||
|
{%- if message.role == "system" %}
|
||||||
|
{%- if not loop.first %}
|
||||||
|
{{- raise_exception('System message must be at the beginning.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- elif message.role == "user" %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||||
|
{%- elif message.role == "assistant" %}
|
||||||
|
{%- set reasoning_content = '' %}
|
||||||
|
{%- if message.reasoning_content is string %}
|
||||||
|
{%- set reasoning_content = message.reasoning_content %}
|
||||||
|
{%- else %}
|
||||||
|
{%- if '</think>' in content %}
|
||||||
|
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||||
|
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set reasoning_content = reasoning_content|trim %}
|
||||||
|
{%- if (preserve_thinking is defined and preserve_thinking is true) or (loop.index0 > ns.last_query_index) %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
|
||||||
|
{%- for tool_call in message.tool_calls %}
|
||||||
|
{%- if tool_call.function is defined %}
|
||||||
|
{%- set tool_call = tool_call.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if loop.first %}
|
||||||
|
{%- if content|trim %}
|
||||||
|
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- else %}
|
||||||
|
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if tool_call.arguments is defined %}
|
||||||
|
{%- for args_name, args_value in tool_call.arguments|items %}
|
||||||
|
{{- '<parameter=' + args_name + '>\n' }}
|
||||||
|
{%- set args_value = args_value | string if args_value is string else args_value | tojson | safe %}
|
||||||
|
{{- args_value }}
|
||||||
|
{{- '\n</parameter>\n' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '</function>\n</tool_call>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- elif message.role == "tool" %}
|
||||||
|
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||||
|
{{- '<|im_start|>user' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<tool_response>\n' }}
|
||||||
|
{{- content }}
|
||||||
|
{{- '\n</tool_response>' }}
|
||||||
|
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- elif loop.last %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- else %}
|
||||||
|
{{- raise_exception('Unexpected message role.') }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|im_start|>assistant\n' }}
|
||||||
|
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||||
|
{{- '<think>\n\n</think>\n\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<think>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
@@ -62,6 +62,30 @@ impl TpLeaderModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Chunked image prefill on rank 0. Only the vision-capable
|
||||||
|
/// `qwen3_5` arch supports it; the dense `qwen3` arch has no tower.
|
||||||
|
pub fn prefill_with_images_chunked(
|
||||||
|
&mut self,
|
||||||
|
tokens: &[u32],
|
||||||
|
base_offset: usize,
|
||||||
|
image_pixels: &[candle_core::Tensor],
|
||||||
|
image_token_id: u32,
|
||||||
|
chunk_size: usize,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
TpLeaderModel::Qwen3_5(m) => m.prefill_with_images_chunked(
|
||||||
|
tokens,
|
||||||
|
base_offset,
|
||||||
|
image_pixels,
|
||||||
|
image_token_id,
|
||||||
|
chunk_size,
|
||||||
|
),
|
||||||
|
TpLeaderModel::Qwen3(_) => {
|
||||||
|
candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
match self {
|
match self {
|
||||||
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
TpLeaderModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
@@ -221,9 +245,67 @@ pub struct WorkerPool {
|
|||||||
/// Phase 4 the load itself moves onto the worker and that bridge
|
/// Phase 4 the load itself moves onto the worker and that bridge
|
||||||
/// goes away.
|
/// goes away.
|
||||||
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
pub(crate) leader_worker: std::sync::Arc<super::device_worker::DeviceWorkerHandle>,
|
||||||
|
/// Cached handle to the leader's NCCL `Comm`, fetched at `init_nccl`
|
||||||
|
/// while the worker thread is responsive. The TP step watchdog uses
|
||||||
|
/// it to `ncclCommAbort` a wedged collective from the async thread —
|
||||||
|
/// the one NCCL op allowed concurrently with an in-flight collective,
|
||||||
|
/// and the only way to unblock the in-process leader thread so
|
||||||
|
/// recovery's `unload` doesn't itself hang (#17 Stage 2). `None` if
|
||||||
|
/// init couldn't cache it; the watchdog then logs that it can't abort.
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
leader_comm: Option<nccl_state::SendComm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deadline applied to a single TP forward step (#17 Stage 2).
///
/// A healthy decode step or chunked prefill completes in well under a
/// second, whereas a wedged NCCL collective never returns. The default
/// of 120 s is deliberately generous so no legitimate step trips it.
///
/// The value can be overridden with `NEURON_TP_STEP_TIMEOUT_S` (whole
/// seconds). An unset variable, a value that does not parse as `u64`
/// after trimming, or `0` all fall back to the default.
#[cfg(feature = "cuda")]
fn tp_step_timeout() -> std::time::Duration {
    const DEFAULT_SECS: u64 = 120;

    let configured = match std::env::var("NEURON_TP_STEP_TIMEOUT_S") {
        Ok(raw) => raw.trim().parse::<u64>().ok().filter(|&s| s > 0),
        Err(_) => None,
    };

    std::time::Duration::from_secs(configured.unwrap_or(DEFAULT_SECS))
}
|
||||||
|
|
||||||
impl WorkerPool {
|
impl WorkerPool {
|
||||||
|
/// Abort the cached leader NCCL comm after the watchdog has decided a
/// collective is wedged (#17 Stage 2).
///
/// Every outcome is logged at ERROR so a real-world hang leaves a
/// greppable forensic trail (`tp watchdog:` / `ncclCommAbort`):
///
/// - the deadline breach itself (always),
/// - then one of: abort succeeded, abort failed, or no comm handle was
///   cached so nothing could be aborted.
///
/// Calling abort from this async thread while the worker thread is
/// blocked inside the collective is the one concurrent NCCL op the
/// library sanctions — it is how a stuck/failed collective is unblocked.
///
/// `secs` is only used for the `timeout_s` log field.
#[cfg(feature = "cuda")]
fn watchdog_abort_leader_comm(&self, model_id: &str, secs: u64) {
    tracing::error!(
        model = %model_id,
        timeout_s = secs,
        "tp watchdog: leader forward exceeded deadline — NCCL collective wedged; \
         aborting comm to unblock the leader thread for auto-recovery"
    );

    // `None` → no handle cached; `Some(result)` → abort was attempted.
    let abort_outcome = self.leader_comm.as_ref().map(|comm| comm.0.abort());

    match abort_outcome {
        None => tracing::error!(
            model = %model_id,
            "tp watchdog: no cached leader comm handle — cannot abort; recovery will rely \
             on a process restart"
        ),
        Some(Err(e)) => tracing::error!(
            model = %model_id, error = ?e,
            "tp watchdog: ncclCommAbort failed — recovery may stall until a process restart"
        ),
        Some(Ok(())) => tracing::error!(
            model = %model_id,
            "tp watchdog: ncclCommAbort succeeded — wedged collective unblocked; \
             failing the step so the model auto-recovers (unload+reload)"
        ),
    }
}
|
||||||
|
|
||||||
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
/// Spawn `world_size - 1` worker subprocesses. Rank 0 is the
|
||||||
/// leader (in-process) and is *not* spawned here — the leader
|
/// leader (in-process) and is *not* spawned here — the leader
|
||||||
/// holds rank 0's NCCL Comm and shard in its own address space.
|
/// holds rank 0's NCCL Comm and shard in its own address space.
|
||||||
@@ -300,6 +382,8 @@ impl WorkerPool {
|
|||||||
workers,
|
workers,
|
||||||
exe,
|
exe,
|
||||||
leader_worker,
|
leader_worker,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
leader_comm: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,6 +464,23 @@ impl WorkerPool {
|
|||||||
world_size = self.world_size,
|
world_size = self.world_size,
|
||||||
"NCCL communicator established across all ranks"
|
"NCCL communicator established across all ranks"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Cache the leader's Comm handle now, while the worker thread is
|
||||||
|
// responsive, so the TP step watchdog can abort a wedged
|
||||||
|
// collective later (it can't fetch it then — the thread is stuck).
|
||||||
|
// (#17 Stage 2.)
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
self.leader_comm = self.leader_worker.get_leader_comm().await;
|
||||||
|
if self.leader_comm.is_some() {
|
||||||
|
tracing::debug!("cached leader NCCL comm handle for the TP step watchdog");
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"could not cache leader NCCL comm handle; the TP step watchdog will be \
|
||||||
|
unable to abort a wedged collective (a hang would need a process restart)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,10 +705,27 @@ impl WorkerPool {
|
|||||||
// that's the invariant the whole refactor exists to
|
// that's the invariant the whole refactor exists to
|
||||||
// preserve.
|
// preserve.
|
||||||
let leader_start = std::time::Instant::now();
|
let leader_start = std::time::Instant::now();
|
||||||
let leader_result = self
|
let timeout = tp_step_timeout();
|
||||||
|
let leader_fut = self
|
||||||
.leader_worker
|
.leader_worker
|
||||||
.tp_forward_logits(leader_handle, tokens, offset)
|
.tp_forward_logits(leader_handle, tokens, offset);
|
||||||
.await;
|
let leader_result = match tokio::time::timeout(timeout, leader_fut).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(_elapsed) => {
|
||||||
|
// Watchdog (#17 Stage 2): the NCCL collective is wedged.
|
||||||
|
// Abort the leader comm to unblock its thread, then fail
|
||||||
|
// the step WITHOUT draining (the subprocess workers are
|
||||||
|
// wedged too; recovery's unload kills them). The error
|
||||||
|
// poisons the model → auto-recovery, which no longer hangs
|
||||||
|
// because the leader thread is now responsive.
|
||||||
|
self.watchdog_abort_leader_comm(model_id, timeout.as_secs());
|
||||||
|
anyhow::bail!(
|
||||||
|
"tp watchdog: leader forward exceeded {}s deadline; aborted wedged NCCL \
|
||||||
|
comm — model will auto-recover",
|
||||||
|
timeout.as_secs()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
let leader_ok = leader_result.is_ok();
|
let leader_ok = leader_result.is_ok();
|
||||||
let leader_ms = leader_start.elapsed().as_millis();
|
let leader_ms = leader_start.elapsed().as_millis();
|
||||||
// Surface the leader's own error at WARN before draining
|
// Surface the leader's own error at WARN before draining
|
||||||
@@ -687,6 +805,146 @@ impl WorkerPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Single-shot vision prefill across all ranks — the image-bearing
/// counterpart of [`Self::generate_step`].
///
/// The fan-out / leader-forward / drain shape is the same as the text
/// path, but every rank runs the encode + splice path:
///
/// - each subprocess worker is sent `GenerateStepWithImages` (carrying
///   the source `image_data_uris`), preprocesses + encodes through its
///   replicated tower and splices locally;
/// - the leader does the same encode + splice + forward on its device
///   worker thread via `tp_forward_logits_with_images`.
///
/// The row-parallel `AllReduce`s synchronise the ranks exactly as in the
/// text path. Because the tower is replicated and the preprocess is
/// deterministic, every rank's spliced hidden state matches — no
/// embedding broadcast is needed. Prefill only; decode reuses
/// `generate_step`.
///
/// # Errors
///
/// - a failed send to any worker during fan-out (returned immediately);
/// - the leader forward exceeding the `tp_step_timeout()` deadline: the
///   leader comm is aborted and the step fails WITHOUT draining workers;
/// - a leader failure and/or any worker failure reported during drain.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub async fn generate_step_with_images(
    &mut self,
    model_id: &str,
    leader_handle: super::device_worker::TpHandle,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
    chunk_size: usize,
) -> Result<Vec<f32>> {
    let step_start = std::time::Instant::now();
    let tokens_len = tokens.len();
    tracing::debug!(
        model = %model_id,
        tokens = tokens_len,
        offset,
        images = image_data_uris.len(),
        chunk_size,
        "WorkerPool::generate_step_with_images: fan-out"
    );

    // 1. Fan the image-bearing prefill out to the subprocess workers.
    for worker in self.workers.iter_mut() {
        let request = WorkerRequest::GenerateStepWithImages {
            model_id: model_id.to_string(),
            tokens: tokens.clone(),
            offset,
            image_token_id,
            image_data_uris: image_data_uris.clone(),
            chunk_size,
        };
        worker.send_only(&request).await?;
    }

    // 2. Run the leader's image forward on its device worker thread. The
    //    AllReduce CustomOps block until every worker issues the matching
    //    collective; returning CPU-side logits keeps the device tensor
    //    from escaping the worker thread.
    let leader_start = std::time::Instant::now();
    let deadline = tp_step_timeout();
    let leader_forward = self.leader_worker.tp_forward_logits_with_images(
        leader_handle,
        tokens,
        offset,
        image_token_id,
        image_data_uris,
        chunk_size,
    );
    let leader_result = match tokio::time::timeout(deadline, leader_forward).await {
        Ok(outcome) => outcome,
        Err(_elapsed) => {
            // Watchdog (#17 Stage 2) — same handling as generate_step.
            // Vision prefill is still well under the deadline on healthy
            // hardware, so a timeout means a wedged collective.
            self.watchdog_abort_leader_comm(model_id, deadline.as_secs());
            anyhow::bail!(
                "tp watchdog: leader image forward exceeded {}s deadline; aborted wedged \
                 NCCL comm — model will auto-recover",
                deadline.as_secs()
            );
        }
    };
    let leader_ok = leader_result.is_ok();
    let leader_ms = leader_start.elapsed().as_millis();
    if let Err(leader_err) = &leader_result {
        // Surface the leader's own error before draining the workers.
        let detail = format!("{leader_err:#}");
        tracing::warn!(
            model = %model_id,
            tokens = tokens_len,
            offset,
            leader_ms,
            error = %detail,
            "WorkerPool::generate_step_with_images: leader forward failed"
        );
    }

    // 3. ALWAYS drain worker responses, whatever the leader's outcome, so
    //    stale GenerateStepOk replies don't poison the next request's
    //    recv (same invariant as generate_step).
    let worker_errors = drain_workers(&mut self.workers, |reply| match reply {
        WorkerResponse::GenerateStepOk => Ok(()),
        WorkerResponse::Error { kind, message } => Err(format!("[{kind}]: {message}")),
        other => Err(format!("expected GenerateStepOk, got {other:?}")),
    })
    .await;
    tracing::debug!(
        model = %model_id,
        leader_ms,
        leader_ok,
        errors = worker_errors.len(),
        total_ms = step_start.elapsed().as_millis(),
        "WorkerPool::generate_step_with_images: workers drained"
    );

    // 4. Combine the leader outcome with the drained worker outcomes.
    let workers_clean = worker_errors.is_empty();
    match (leader_result, workers_clean) {
        (Ok(values), true) => Ok(values),
        (Ok(_), false) => anyhow::bail!(
            "GenerateStepWithImages: leader succeeded but workers failed: {}",
            worker_errors.join("; ")
        ),
        (Err(e), true) => {
            Err(anyhow::Error::new(e).context("GenerateStepWithImages: leader forward failed"))
        }
        (Err(e), false) => Err(anyhow::Error::new(e).context(format!(
            "GenerateStepWithImages: leader forward failed and workers also failed: {}",
            worker_errors.join("; ")
        ))),
    }
}
|
||||||
|
|
||||||
/// Reset the KV cache for `model_id` on every rank. Called at the
|
/// Reset the KV cache for `model_id` on every rank. Called at the
|
||||||
/// start of every inference so a fresh request doesn't attend over
|
/// start of every inference so a fresh request doesn't attend over
|
||||||
/// the previous one's tokens.
|
/// the previous one's tokens.
|
||||||
|
|||||||
@@ -119,40 +119,25 @@ mod cuda_impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `Arc<Comm>` doesn't impl `Send` because `Comm` wraps a raw
|
/// Thin newtype over `Arc<Comm>`, kept for call-site clarity — it marks
|
||||||
/// `ncclComm_t` pointer. The NCCL contract is "operations against a
|
/// the points where a comm handle is intentionally moved across threads
|
||||||
/// given comm must be serialised", not "the handle must stay on the
|
/// (e.g. cached async-side for the TP step watchdog's `ncclCommAbort`).
|
||||||
/// thread that created it" — so it's safe to move an `Arc<Comm>`
|
|
||||||
/// across threads as long as no concurrent ops are issued. The
|
|
||||||
/// pool's outer Mutex serialises us into `spawn_blocking`, so this
|
|
||||||
/// wrapper at the move boundary is the only thing missing.
|
|
||||||
///
|
///
|
||||||
/// `Sync` is also marked safe because the `Arc<Comm>` clones held
|
/// `Send`/`Sync` are provided upstream by `cudarc`'s `Comm` (which
|
||||||
/// by the row-parallel layers are only used from the
|
/// asserts the NCCL thread-safety invariant, including aborting from a
|
||||||
/// `spawn_blocking` thread driving the forward pass; concurrent
|
/// different thread than one inside a collective), so this type derives
|
||||||
/// access from another thread would still be a bug.
|
/// them automatically — no manual `unsafe impl` here.
|
||||||
pub struct SendComm(pub Arc<Comm>);
|
pub struct SendComm(pub Arc<Comm>);
|
||||||
|
|
||||||
// SAFETY: see the doc-comment above; the invariant is enforced at
|
|
||||||
// the call site (pool Mutex + single spawn_blocking thread), not at
|
|
||||||
// the type level.
|
|
||||||
unsafe impl Send for SendComm {}
|
|
||||||
unsafe impl Sync for SendComm {}
|
|
||||||
|
|
||||||
impl SendComm {
|
impl SendComm {
|
||||||
pub fn into_inner(self) -> Arc<Comm> {
|
pub fn into_inner(self) -> Arc<Comm> {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SAFETY: `cudarc::nccl::Comm` contains a raw `ncclComm_t` pointer
|
// `NcclState`'s `Send`/`Sync` are auto-derived: its `Arc<Comm>` and
|
||||||
// (libnccl-allocated state). NCCL requires that operations against
|
// `Arc<CudaContext>` fields are now `Send`/`Sync` (cudarc asserts the
|
||||||
// one Comm be issued one at a time; we serialise access by storing
|
// comm thread-safety invariant), so no manual `unsafe impl` is needed.
|
||||||
// NcclState behind a Mutex in `WorkerPool`. The Comm itself is
|
|
||||||
// move-safe — NCCL doesn't track the calling OS thread, only the
|
|
||||||
// stream the operations are dispatched against.
|
|
||||||
unsafe impl Send for NcclState {}
|
|
||||||
unsafe impl Sync for NcclState {}
|
|
||||||
|
|
||||||
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
/// Generate a fresh NCCL `Id` and return it hex-encoded. Used by
|
||||||
/// the leader to mint the shared communicator id which is then
|
/// the leader to mint the shared communicator id which is then
|
||||||
|
|||||||
@@ -88,6 +88,33 @@ pub enum WorkerRequest {
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Like `GenerateStep` but the prefill carries image content. Every
|
||||||
|
/// rank preprocesses the same `image_data_uris` through its
|
||||||
|
/// *replicated* vision tower, splices the resulting patch embeddings
|
||||||
|
/// at `image_token_id` positions, and runs the forward — the
|
||||||
|
/// row-parallel `AllReduce`s still synchronise every rank. Because
|
||||||
|
/// the tower is replicated and `preprocess_data_uri` is
|
||||||
|
/// deterministic, the spliced hidden state is identical on every
|
||||||
|
/// rank, so no embedding broadcast is needed. Sent only for the
|
||||||
|
/// (single-shot) image-bearing prefill; decode steps use plain
|
||||||
|
/// `GenerateStep`. Worker replies with the same `GenerateStepOk`.
|
||||||
|
GenerateStepWithImages {
|
||||||
|
model_id: String,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
offset: usize,
|
||||||
|
/// `<|image_pad|>` sentinel id (248056 for Qwen3.6); splice
|
||||||
|
/// target in the expanded token stream.
|
||||||
|
image_token_id: u32,
|
||||||
|
/// Source image data URIs (`data:image/...;base64,...`), one per
|
||||||
|
/// image in prompt order. Each rank decodes + preprocesses these
|
||||||
|
/// identically; tens of KB each, so cheap over the stdin pipe.
|
||||||
|
image_data_uris: Vec<String>,
|
||||||
|
/// Prefill chunk size (tokens). Sent explicitly so every rank
|
||||||
|
/// walks the prompt in identical windows and the per-chunk
|
||||||
|
/// row-parallel collectives stay paired across ranks.
|
||||||
|
chunk_size: usize,
|
||||||
|
},
|
||||||
|
|
||||||
/// Reset the KV cache for this model on this rank. Sent at the
|
/// Reset the KV cache for this model on this rank. Sent at the
|
||||||
/// start of every inference so a fresh request doesn't accidentally
|
/// start of every inference so a fresh request doesn't accidentally
|
||||||
/// attend over the previous one's tokens.
|
/// attend over the previous one's tokens.
|
||||||
@@ -191,6 +218,33 @@ mod tests {
|
|||||||
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
assert_eq!(wire, r#"{"op":"init","comm_id":"deadbeef"}"#);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_generate_step_with_images_round_trip() {
|
||||||
|
let req = WorkerRequest::GenerateStepWithImages {
|
||||||
|
model_id: "Qwen/Qwen3.6-27B".into(),
|
||||||
|
tokens: vec![1, 2, 248056, 3],
|
||||||
|
offset: 0,
|
||||||
|
image_token_id: 248056,
|
||||||
|
image_data_uris: vec!["data:image/png;base64,AAA=".into()],
|
||||||
|
chunk_size: 512,
|
||||||
|
};
|
||||||
|
let wire = serde_json::to_string(&req).unwrap();
|
||||||
|
assert!(wire.contains(r#""op":"generate_step_with_images""#));
|
||||||
|
match roundtrip(&req) {
|
||||||
|
WorkerRequest::GenerateStepWithImages {
|
||||||
|
tokens,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(tokens, vec![1, 2, 248056, 3]);
|
||||||
|
assert_eq!(image_token_id, 248056);
|
||||||
|
assert_eq!(image_data_uris.len(), 1);
|
||||||
|
}
|
||||||
|
other => panic!("expected GenerateStepWithImages, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_shutdown_round_trip() {
|
fn request_shutdown_round_trip() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ use super::tp_linear::{ColumnParallelLinear, RowParallelLinear};
|
|||||||
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
|
use crate::harness::arch::qwen3_5::linear_attn::repeat_interleave;
|
||||||
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
use crate::harness::arch::qwen3_5::rmsnorm::{Qwen3_5RmsNorm, Qwen3_5RmsNormGated, l2norm};
|
||||||
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
use crate::harness::arch::qwen3_5::rope::RotaryEmbedding;
|
||||||
|
use crate::harness::arch::qwen3_5::splice_runs;
|
||||||
|
use crate::harness::arch::qwen3_5::vision::VisionTower;
|
||||||
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
pub use crate::harness::arch::qwen3_5::{Config, TextConfig};
|
||||||
|
|
||||||
// ─── linear-attention (Gated DeltaNet) ──────────────────────────────
|
// ─── linear-attention (Gated DeltaNet) ──────────────────────────────
|
||||||
@@ -524,7 +526,8 @@ impl TpQwen3_5Attention {
|
|||||||
&mut self,
|
&mut self,
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
attn_mask: Option<&Tensor>,
|
attn_mask: Option<&Tensor>,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let (b, l, _) = x.dims3()?;
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
@@ -557,7 +560,7 @@ impl TpQwen3_5Attention {
|
|||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.contiguous()?;
|
.contiguous()?;
|
||||||
|
|
||||||
let (q, k) = self.rotary.apply(&q, &k, offset)?;
|
let (q, k) = self.rotary.apply_cos_sin(&q, &k, cos, sin)?;
|
||||||
let (k, v) = self.kv_cache.append(&k, &v)?;
|
let (k, v) = self.kv_cache.append(&k, &v)?;
|
||||||
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
@@ -805,11 +808,12 @@ impl TpQwen3_5DecoderLayer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
x: &Tensor,
|
x: &Tensor,
|
||||||
attn_mask: Option<&Tensor>,
|
attn_mask: Option<&Tensor>,
|
||||||
offset: usize,
|
cos: &Tensor,
|
||||||
|
sin: &Tensor,
|
||||||
) -> candle_core::Result<Tensor> {
|
) -> candle_core::Result<Tensor> {
|
||||||
let h = self.input_layernorm.forward(x)?;
|
let h = self.input_layernorm.forward(x)?;
|
||||||
let attn_out = match &mut self.attention {
|
let attn_out = match &mut self.attention {
|
||||||
TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, offset)?,
|
TpAttentionKind::Full(attn) => attn.forward(&h, attn_mask, cos, sin)?,
|
||||||
TpAttentionKind::Linear(net) => net.forward(&h)?,
|
TpAttentionKind::Linear(net) => net.forward(&h)?,
|
||||||
};
|
};
|
||||||
let x = (x + attn_out)?;
|
let x = (x + attn_out)?;
|
||||||
@@ -832,6 +836,15 @@ pub struct TpQwen3_5Model {
|
|||||||
embed_tokens: Embedding,
|
embed_tokens: Embedding,
|
||||||
layers: Vec<TpQwen3_5DecoderLayer>,
|
layers: Vec<TpQwen3_5DecoderLayer>,
|
||||||
norm: Qwen3_5RmsNorm,
|
norm: Qwen3_5RmsNorm,
|
||||||
|
/// Replicated rotary, shared with every full-attention layer. The
|
||||||
|
/// model builds the per-forward cos/sin (interleaved M-RoPE for image
|
||||||
|
/// tokens, plain for text) once and the layers apply it. Identical on
|
||||||
|
/// every rank, so per-rank position ids stay consistent.
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
/// `offset + rope_delta` is the text-axis decode position; set from
|
||||||
|
/// `get_rope_index` during a vision prefill, reset in `clear_kv_cache`.
|
||||||
|
/// See `Qwen3_5Model::rope_delta`.
|
||||||
|
rope_delta: i64,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@@ -898,6 +911,8 @@ impl TpQwen3_5Model {
|
|||||||
embed_tokens,
|
embed_tokens,
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
|
rotary,
|
||||||
|
rope_delta: 0,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
@@ -954,6 +969,8 @@ impl TpQwen3_5Model {
|
|||||||
embed_tokens,
|
embed_tokens,
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
|
rotary,
|
||||||
|
rope_delta: 0,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
@@ -967,6 +984,14 @@ impl TpQwen3_5Model {
|
|||||||
for l in &mut self.layers {
|
for l in &mut self.layers {
|
||||||
l.clear_kv_cache();
|
l.clear_kv_cache();
|
||||||
}
|
}
|
||||||
|
self.rope_delta = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the decode `rope_delta` computed by `get_rope_index` during a
|
||||||
|
/// vision prefill, so decode after the image resumes text positions
|
||||||
|
/// from the image-compressed counter.
|
||||||
|
pub fn set_rope_delta(&mut self, delta: i64) {
|
||||||
|
self.rope_delta = delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
@@ -978,15 +1003,88 @@ impl TpQwen3_5Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
|
self.forward_inner(input, offset, None, None, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward for a vision-prefill chunk: optional image-embedding
|
||||||
|
/// splice plus explicit interleaved-M-RoPE `position_ids` (the
|
||||||
|
/// chunk's slice of the full prompt's 3D positions). Used by
|
||||||
|
/// `TpQwen3_5ForCausalLM::prefill_with_images_chunked`, which
|
||||||
|
/// computes the positions once over the whole prompt and slices them
|
||||||
|
/// per chunk so every rank steps in lockstep.
|
||||||
|
pub fn forward_with_positions(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
position_ids: &Tensor,
|
||||||
|
image_embeds: Option<&Tensor>,
|
||||||
|
image_token_id: Option<u32>,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
self.forward_inner(
|
||||||
|
input,
|
||||||
|
offset,
|
||||||
|
image_embeds,
|
||||||
|
image_token_id,
|
||||||
|
Some(position_ids),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared forward. Splices image embeddings at `image_token_id`
|
||||||
|
/// positions when present, then builds the rotary cos/sin — from the
|
||||||
|
/// explicit `position_ids` (interleaved M-RoPE, vision) when given,
|
||||||
|
/// else plain positions at `offset + rope_delta` (text / decode) —
|
||||||
|
/// and runs the sharded decoder stack. The TP replicated-hidden-state
|
||||||
|
/// invariant holds because every rank encodes the same pixels and
|
||||||
|
/// computes the same positions.
|
||||||
|
fn forward_inner(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
image_embeds: Option<&Tensor>,
|
||||||
|
image_token_id: Option<u32>,
|
||||||
|
position_ids: Option<&Tensor>,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
let (b, l) = input.dims2()?;
|
let (b, l) = input.dims2()?;
|
||||||
let mut h = self.embed_tokens.forward(input)?;
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
if let (Some(img), Some(tok_id)) = (image_embeds, image_token_id) {
|
||||||
|
let ids: Vec<u32> = input.flatten_all()?.to_vec1()?;
|
||||||
|
let mut positions: Vec<u32> = Vec::with_capacity(img.dim(0)?);
|
||||||
|
for (idx, id) in ids.iter().enumerate() {
|
||||||
|
if *id == tok_id {
|
||||||
|
positions.push(idx as u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let n_img_tokens = img.dim(0)?;
|
||||||
|
if positions.len() != n_img_tokens {
|
||||||
|
candle_core::bail!(
|
||||||
|
"TP forward: chunk has {} image-token positions but image_embeds carries \
|
||||||
|
{} tokens — patch-count expansion / chunk slicing mismatch",
|
||||||
|
positions.len(),
|
||||||
|
n_img_tokens,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !positions.is_empty() {
|
||||||
|
let img = img.to_dtype(self.dtype)?;
|
||||||
|
h = splice_runs(&h, &img, &positions)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (cos, sin) = match position_ids {
|
||||||
|
Some(pos) => self.rotary.mrope_cos_sin(pos)?,
|
||||||
|
None => {
|
||||||
|
let base = (offset as i64 + self.rope_delta).max(0) as usize;
|
||||||
|
self.rotary.plain_cos_sin(base, l)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let causal = if l == 1 {
|
let causal = if l == 1 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(self.causal_mask(b, l, offset)?)
|
Some(self.causal_mask(b, l, offset)?)
|
||||||
};
|
};
|
||||||
for layer in &mut self.layers {
|
for layer in &mut self.layers {
|
||||||
h = layer.forward(&h, causal.as_ref(), offset)?;
|
h = layer.forward(&h, causal.as_ref(), &cos, &sin)?;
|
||||||
}
|
}
|
||||||
self.norm.forward(&h)
|
self.norm.forward(&h)
|
||||||
}
|
}
|
||||||
@@ -995,6 +1093,41 @@ impl TpQwen3_5Model {
|
|||||||
pub struct TpQwen3_5ForCausalLM {
|
pub struct TpQwen3_5ForCausalLM {
|
||||||
base: TpQwen3_5Model,
|
base: TpQwen3_5Model,
|
||||||
lm_head: super::tp_linear::MaybeQuantLinear,
|
lm_head: super::tp_linear::MaybeQuantLinear,
|
||||||
|
/// Replicated vision tower (TP-vision). Loaded on every rank from
|
||||||
|
/// the full, unsharded `model.visual.*` weights; `None` for
|
||||||
|
/// text-only checkpoints. Each rank encodes the same image
|
||||||
|
/// independently — no sharding, no broadcast — which keeps the
|
||||||
|
/// spliced input embeddings identical across ranks (the
|
||||||
|
/// replicated-hidden-state invariant the sharded layers rely on).
|
||||||
|
vision: Option<VisionTower>,
|
||||||
|
/// `<|image_pad|>` sentinel id (mirrors `Config::image_token_id`);
|
||||||
|
/// the splice target for `forward_with_vision`.
|
||||||
|
image_token_id: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load the replicated vision tower from the unsharded `model.visual.*`
|
||||||
|
/// weights when the config carries a `vision_config` block. Shared by
|
||||||
|
/// the cuda and non-cuda `load` variants. `vb.pp("model.visual")`
|
||||||
|
/// resolves against the same full safetensors every rank mmaps; plain
|
||||||
|
/// `.get()` on a `ShardedVarBuilder` returns the full (replicated)
|
||||||
|
/// tensor, so this loads identically regardless of `world_size`.
|
||||||
|
fn load_replicated_vision_tower(
|
||||||
|
config: &Config,
|
||||||
|
vb: &ShardedVarBuilder,
|
||||||
|
) -> Result<Option<VisionTower>> {
|
||||||
|
match config.vision_config.clone() {
|
||||||
|
Some(vcfg) => {
|
||||||
|
tracing::info!(
|
||||||
|
depth = vcfg.depth,
|
||||||
|
hidden_size = vcfg.hidden_size,
|
||||||
|
"loading qwen3_5 vision tower (TP replicated)"
|
||||||
|
);
|
||||||
|
let tower = VisionTower::load(vcfg, vb.pp("model.visual"))
|
||||||
|
.context("load qwen3_5 vision tower (model.visual.*) [TP replicated]")?;
|
||||||
|
Ok(Some(tower))
|
||||||
|
}
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TpQwen3_5ForCausalLM {
|
impl TpQwen3_5ForCausalLM {
|
||||||
@@ -1012,7 +1145,14 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, comm, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||||
let model = Self { base, lm_head };
|
let vision = load_replicated_vision_tower(&config, vb)?;
|
||||||
|
let image_token_id = config.image_token_id;
|
||||||
|
let model = Self {
|
||||||
|
base,
|
||||||
|
lm_head,
|
||||||
|
vision,
|
||||||
|
image_token_id,
|
||||||
|
};
|
||||||
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||||
Ok(model)
|
Ok(model)
|
||||||
}
|
}
|
||||||
@@ -1029,17 +1169,198 @@ impl TpQwen3_5ForCausalLM {
|
|||||||
let cfg = &config.text_config;
|
let cfg = &config.text_config;
|
||||||
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
let base = TpQwen3_5Model::load(cfg, vb, mmap, rank, world_size, quant)?;
|
||||||
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
let lm_head = build_lm_head(cfg, vb, &base, quant)?;
|
||||||
let model = Self { base, lm_head };
|
let vision = load_replicated_vision_tower(&config, vb)?;
|
||||||
|
let image_token_id = config.image_token_id;
|
||||||
|
let model = Self {
|
||||||
|
base,
|
||||||
|
lm_head,
|
||||||
|
vision,
|
||||||
|
image_token_id,
|
||||||
|
};
|
||||||
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
log_construction_complete(cfg, rank, world_size, quant, model.device());
|
||||||
Ok(model)
|
Ok(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// True when this TP load materialised a replicated vision tower.
|
||||||
|
/// Drives capability advertising and the Stage 3 vision dispatch.
|
||||||
|
pub fn has_vision(&self) -> bool {
|
||||||
|
self.vision.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `<|image_pad|>` sentinel id, when known.
|
||||||
|
pub fn image_token_id(&self) -> Option<u32> {
|
||||||
|
self.image_token_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode one preprocessed `(C, H, W)` image into LM-side patch
|
||||||
|
/// embeddings `(N_lm, hidden)` via this rank's replicated tower.
|
||||||
|
/// Errors when loaded without a vision tower.
|
||||||
|
pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
|
||||||
|
self.vision
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"encode_image: this TP Qwen3.6 load has no vision tower \
|
||||||
|
(config.json::vision_config absent or weights missing)"
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.forward(image)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> candle_core::Result<Tensor> {
|
||||||
let (_, l) = input.dims2()?;
|
let (_, l) = input.dims2()?;
|
||||||
let hidden = self.base.forward(input, offset)?;
|
let hidden = self.base.forward(input, offset)?;
|
||||||
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Forward for a vision-prefill chunk (optional image splice +
|
||||||
|
/// explicit interleaved-M-RoPE `position_ids`). Mirrors `forward`
|
||||||
|
/// but routes through `TpQwen3_5Model::forward_with_positions`.
|
||||||
|
pub fn forward_with_positions(
|
||||||
|
&mut self,
|
||||||
|
input: &Tensor,
|
||||||
|
offset: usize,
|
||||||
|
position_ids: &Tensor,
|
||||||
|
image_embeds: Option<&Tensor>,
|
||||||
|
image_token_id: Option<u32>,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
let hidden = self.base.forward_with_positions(
|
||||||
|
input,
|
||||||
|
offset,
|
||||||
|
position_ids,
|
||||||
|
image_embeds,
|
||||||
|
image_token_id,
|
||||||
|
)?;
|
||||||
|
hidden.i((.., l - 1.., ..))?.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// End-to-end image prefill on one rank: encode each preprocessed
|
||||||
|
/// `(C, H, W)` pixel tensor through this rank's replicated tower,
|
||||||
|
/// concatenate the per-image embeddings along the patch axis, and
|
||||||
|
/// forward with the splice. Shared by the leader (`TpLeaderModel`)
|
||||||
|
/// and the subprocess worker (`WorkerModel`) so every rank runs the
|
||||||
|
/// identical encode → splice → forward and keeps the replicated
|
||||||
|
/// hidden state in lockstep. Returns last-position logits
|
||||||
|
/// `(B, 1, vocab)`, same contract as `forward`.
|
||||||
|
/// Encode every preprocessed `(C,H,W)` image once through this
|
||||||
|
/// rank's replicated tower and concatenate along the patch axis →
|
||||||
|
/// `(sum_patches, hidden)`. Done once per prefill, not per chunk.
|
||||||
|
fn encode_images_concat(&self, image_pixels: &[Tensor]) -> candle_core::Result<Tensor> {
|
||||||
|
let mut per_image = Vec::with_capacity(image_pixels.len());
|
||||||
|
for (idx, img) in image_pixels.iter().enumerate() {
|
||||||
|
let embed = self
|
||||||
|
.encode_image(img)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("encode image[{idx}]: {e:#}")))?;
|
||||||
|
per_image.push(embed);
|
||||||
|
}
|
||||||
|
Tensor::cat(&per_image.iter().collect::<Vec<_>>(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chunked image prefill on one rank. Encodes the image(s) once,
|
||||||
|
/// then walks the (pre-expanded) prompt in `chunk_size`-token
|
||||||
|
/// windows — exactly like the text `chunked_prefill_tp` — splicing
|
||||||
|
/// the patch embeddings into whichever chunk(s) carry `<|image_pad|>`
|
||||||
|
/// positions. Activation memory is bounded by the chunk, not the
|
||||||
|
/// full prompt, so a long vision context no longer single-shot-OOMs.
|
||||||
|
///
|
||||||
|
/// Every rank runs the identical chunk sequence (same `tokens.len()`
|
||||||
|
/// and `chunk_size`), so the row-parallel `AllReduce`s pair up
|
||||||
|
/// chunk-by-chunk across ranks with no extra synchronisation. The KV
|
||||||
|
/// cache accumulates across chunks via the growing offset; only the
|
||||||
|
/// final chunk's last-position logits are returned (intermediate
|
||||||
|
/// chunks just populate the cache, same as the text path).
|
||||||
|
pub fn prefill_with_images_chunked(
|
||||||
|
&mut self,
|
||||||
|
tokens: &[u32],
|
||||||
|
base_offset: usize,
|
||||||
|
image_pixels: &[Tensor],
|
||||||
|
image_token_id: u32,
|
||||||
|
chunk_size: usize,
|
||||||
|
) -> candle_core::Result<Tensor> {
|
||||||
|
if image_pixels.is_empty() {
|
||||||
|
candle_core::bail!("prefill_with_images_chunked: called with zero images");
|
||||||
|
}
|
||||||
|
if tokens.is_empty() {
|
||||||
|
candle_core::bail!("prefill_with_images_chunked: empty prompt");
|
||||||
|
}
|
||||||
|
let chunk_size = chunk_size.max(1);
|
||||||
|
let device = self.device().clone();
|
||||||
|
let image_embeds = self.encode_images_concat(image_pixels)?;
|
||||||
|
|
||||||
|
// Each image's LM grid (lm_gh, lm_gw) = (h/factor, w/factor),
|
||||||
|
// factor = patch×merge. Recomputed per rank from this rank's own
|
||||||
|
// pixel tensors — deterministic, so every rank's grids (and hence
|
||||||
|
// M-RoPE positions) match without crossing the RPC (#14).
|
||||||
|
let factor = self
|
||||||
|
.vision
|
||||||
|
.as_ref()
|
||||||
|
.map(|v| {
|
||||||
|
let c = v.config();
|
||||||
|
c.patch_size * c.spatial_merge_size
|
||||||
|
})
|
||||||
|
.ok_or_else(|| {
|
||||||
|
candle_core::Error::Msg(
|
||||||
|
"prefill_with_images_chunked: loaded without a vision tower".into(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let grids: Vec<(usize, usize)> = image_pixels
|
||||||
|
.iter()
|
||||||
|
.map(|t| {
|
||||||
|
let (_, h, w) = t.dims3()?;
|
||||||
|
Ok::<(usize, usize), candle_core::Error>((h / factor, w / factor))
|
||||||
|
})
|
||||||
|
.collect::<candle_core::Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
// Interleaved-M-RoPE 3D position ids for the whole prompt,
|
||||||
|
// computed once and sliced per chunk so every rank assigns image
|
||||||
|
// tokens their grid coordinates (and text after an image resumes
|
||||||
|
// from the compressed counter). `rope_delta` is stored on the base
|
||||||
|
// model for the decode that follows this prefill. Every chunk —
|
||||||
|
// text or image — uses the M-RoPE slice, because each image shifts
|
||||||
|
// the positions of the text around it.
|
||||||
|
let (text, height, width, delta) =
|
||||||
|
crate::harness::arch::qwen3_5::rope::get_rope_index(tokens, image_token_id, &grids)
|
||||||
|
.map_err(|e| candle_core::Error::Msg(format!("get_rope_index: {e}")))?;
|
||||||
|
self.base.set_rope_delta(delta);
|
||||||
|
let full_pos = crate::harness::arch::qwen3_5::rope::mrope_position_tensor(
|
||||||
|
&text, &height, &width, &device,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut last_logits: Option<Tensor> = None;
|
||||||
|
// Rows of `image_embeds` already spliced by earlier chunks. The
|
||||||
|
// `<|image_pad|>` run is contiguous, so chunks consume embedding
|
||||||
|
// rows in order.
|
||||||
|
let mut img_off = 0usize;
|
||||||
|
let mut start = 0usize;
|
||||||
|
while start < tokens.len() {
|
||||||
|
let end = (start + chunk_size).min(tokens.len());
|
||||||
|
let chunk = &tokens[start..end];
|
||||||
|
let input = Tensor::new(chunk, &device)?.unsqueeze(0)?;
|
||||||
|
let pos_slice = full_pos.narrow(1, start, end - start)?;
|
||||||
|
let n_here = chunk.iter().filter(|&&t| t == image_token_id).count();
|
||||||
|
let logits = if n_here == 0 {
|
||||||
|
self.forward_with_positions(&input, base_offset + start, &pos_slice, None, None)?
|
||||||
|
} else {
|
||||||
|
// Splice the next `n_here` patch rows at this chunk's
|
||||||
|
// local image-pad positions.
|
||||||
|
let rows = image_embeds.narrow(0, img_off, n_here)?;
|
||||||
|
img_off += n_here;
|
||||||
|
self.forward_with_positions(
|
||||||
|
&input,
|
||||||
|
base_offset + start,
|
||||||
|
&pos_slice,
|
||||||
|
Some(&rows),
|
||||||
|
Some(image_token_id),
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
last_logits = Some(logits);
|
||||||
|
start = end;
|
||||||
|
}
|
||||||
|
last_logits
|
||||||
|
.ok_or_else(|| candle_core::Error::Msg("prefill_with_images_chunked: no chunks".into()))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
self.base.clear_kv_cache();
|
self.base.clear_kv_cache();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,6 +47,34 @@ impl WorkerModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Chunked image prefill on this rank. Only the vision-capable
|
||||||
|
/// `qwen3_5` arch has a replicated tower; the dense `qwen3` arch
|
||||||
|
/// errors. The returned logits are discarded by the caller (the
|
||||||
|
/// leader samples from its own rank-0 copy) — the value is the NCCL
|
||||||
|
/// collectives the forward issues, chunk by chunk in lockstep with
|
||||||
|
/// the leader.
|
||||||
|
fn prefill_with_images_chunked(
|
||||||
|
&mut self,
|
||||||
|
tokens: &[u32],
|
||||||
|
base_offset: usize,
|
||||||
|
image_pixels: &[candle_core::Tensor],
|
||||||
|
image_token_id: u32,
|
||||||
|
chunk_size: usize,
|
||||||
|
) -> candle_core::Result<candle_core::Tensor> {
|
||||||
|
match self {
|
||||||
|
WorkerModel::Qwen3_5(m) => m.prefill_with_images_chunked(
|
||||||
|
tokens,
|
||||||
|
base_offset,
|
||||||
|
image_pixels,
|
||||||
|
image_token_id,
|
||||||
|
chunk_size,
|
||||||
|
),
|
||||||
|
WorkerModel::Qwen3(_) => {
|
||||||
|
candle_core::bail!("prefill_with_images_chunked: qwen3 (dense) has no vision tower")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn clear_kv_cache(&mut self) {
|
fn clear_kv_cache(&mut self) {
|
||||||
match self {
|
match self {
|
||||||
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
WorkerModel::Qwen3(m) => m.clear_kv_cache(),
|
||||||
@@ -167,6 +195,21 @@ impl WorkerState {
|
|||||||
tokens,
|
tokens,
|
||||||
offset,
|
offset,
|
||||||
} => self.handle_generate_step(&model_id, tokens, offset),
|
} => self.handle_generate_step(&model_id, tokens, offset),
|
||||||
|
WorkerRequest::GenerateStepWithImages {
|
||||||
|
model_id,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
chunk_size,
|
||||||
|
} => self.handle_generate_step_with_images(
|
||||||
|
&model_id,
|
||||||
|
tokens,
|
||||||
|
offset,
|
||||||
|
image_token_id,
|
||||||
|
image_data_uris,
|
||||||
|
chunk_size,
|
||||||
|
),
|
||||||
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
WorkerRequest::ClearKvCache { model_id } => self.handle_clear_kv_cache(&model_id),
|
||||||
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
WorkerRequest::UnloadModel { model_id } => self.handle_unload_model(&model_id),
|
||||||
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
WorkerRequest::Shutdown => WorkerResponse::Bye,
|
||||||
@@ -418,6 +461,117 @@ impl WorkerState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Image-bearing prefill step for this tensor-parallel rank.
///
/// Each source data URI is run through the same deterministic
/// `preprocess_data_uri` the leader uses, turned into a `(3, h, w)`
/// tensor on the model's device, and handed to
/// `prefill_with_images_chunked` together with the token ids. The
/// prefill result is discarded here (the leader samples from rank 0);
/// this rank runs the forward for the sake of the row-parallel
/// `AllReduce`s.
///
/// Returns `WorkerResponse::GenerateStepOk` on success, otherwise a
/// `WorkerResponse::Error` whose `kind` is one of:
/// - `bad_request`: no images were supplied, or an image failed to
///   preprocess;
/// - `model_not_loaded`: `model_id` is not present in `self.models`;
/// - `forward_failed`: building an image tensor or the prefill failed.
#[cfg(feature = "cuda")]
fn handle_generate_step_with_images(
    &mut self,
    model_id: &str,
    tokens: Vec<u32>,
    offset: usize,
    image_token_id: u32,
    image_data_uris: Vec<String>,
    chunk_size: usize,
) -> WorkerResponse {
    use crate::harness::preprocess::{PreprocessProfile, preprocess_data_uri};
    use candle_core::Tensor;

    // An image step without images is a caller error; reject it before
    // looking the model up so the error precedence stays unchanged.
    if image_data_uris.is_empty() {
        return WorkerResponse::Error {
            kind: "bad_request".into(),
            message: "GenerateStepWithImages with zero images".into(),
        };
    }

    let model = match self.models.get_mut(model_id) {
        Some(loaded) => loaded,
        None => {
            return WorkerResponse::Error {
                kind: "model_not_loaded".into(),
                message: format!("model '{model_id}' not loaded on rank {}", self.config.rank),
            };
        }
    };
    let device = model.device().clone();

    // Preprocess every image exactly as the leader does so the encoded
    // embeddings (and therefore the spliced hidden state and per-image
    // grids) agree across ranks. Native-aspect `smart_resize` (#14) is
    // deterministic, so each rank derives the same dims.
    let profile = PreprocessProfile::qwen3_6();
    let mut image_tensors: Vec<Tensor> = Vec::with_capacity(image_data_uris.len());
    for (idx, uri) in image_data_uris.iter().enumerate() {
        let (raw, height, width) = match preprocess_data_uri(uri, &profile) {
            Ok(decoded) => decoded,
            Err(e) => {
                return WorkerResponse::Error {
                    kind: "bad_request".into(),
                    message: format!("preprocess image[{idx}]: {e:#}"),
                };
            }
        };
        let tensor = match Tensor::from_vec(raw, (3, height as usize, width as usize), &device) {
            Ok(built) => built,
            Err(e) => {
                return WorkerResponse::Error {
                    kind: "forward_failed".into(),
                    message: format!("build image[{idx}] tensor: {e}"),
                };
            }
        };
        image_tensors.push(tensor);
    }

    let started = std::time::Instant::now();
    tracing::debug!(
        rank = self.config.rank,
        model = %model_id,
        tokens = tokens.len(),
        offset,
        images = image_tensors.len(),
        chunk_size,
        "worker GenerateStepWithImages: chunked prefill starting"
    );

    // The logits are dropped: the leader samples from its own rank-0
    // copy. The chunked prefill builds its own per-chunk input tensors.
    match model.prefill_with_images_chunked(
        &tokens,
        offset,
        &image_tensors,
        image_token_id,
        chunk_size,
    ) {
        Err(e) => {
            tracing::warn!(
                rank = self.config.rank,
                model = %model_id,
                elapsed_ms = started.elapsed().as_millis(),
                error = %e,
                "worker GenerateStepWithImages: forward failed"
            );
            return WorkerResponse::Error {
                kind: "forward_failed".into(),
                message: format!("TP image forward: {e}"),
            };
        }
        Ok(_) => {}
    }

    tracing::debug!(
        rank = self.config.rank,
        model = %model_id,
        elapsed_ms = started.elapsed().as_millis(),
        "worker GenerateStepWithImages: forward done"
    );
    WorkerResponse::GenerateStepOk
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
fn handle_generate_step_with_images(
|
||||||
|
&mut self,
|
||||||
|
_model_id: &str,
|
||||||
|
_tokens: Vec<u32>,
|
||||||
|
_offset: usize,
|
||||||
|
_image_token_id: u32,
|
||||||
|
_image_data_uris: Vec<String>,
|
||||||
|
_chunk_size: usize,
|
||||||
|
) -> WorkerResponse {
|
||||||
|
WorkerResponse::Error {
|
||||||
|
kind: "cuda_feature_not_enabled".into(),
|
||||||
|
message: "GenerateStepWithImages requires --features cuda".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
fn handle_clear_kv_cache(&mut self, model_id: &str) -> WorkerResponse {
|
||||||
let Some(model) = self.models.get_mut(model_id) else {
|
let Some(model) = self.models.get_mut(model_id) else {
|
||||||
|
|||||||
@@ -646,6 +646,54 @@ mod tests {
|
|||||||
assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,AAA=");
|
assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,AAA=");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn multiple_images_translate_in_order_and_tolerate_detail() {
    // C2: a Responses request with several InputImage parts (one of
    // them carrying `detail`) must become a chat Parts array that keeps
    // the images in input order and uses the `image_url.url` shape the
    // chat vision path (`extract_images_from_request`) walks. `detail`
    // has no chat-completions analogue we forward, so it is dropped,
    // but its presence must not break translation.
    let content_parts = vec![
        ResponsesContentPart::InputText {
            text: "compare these".into(),
        },
        ResponsesContentPart::InputImage {
            image_url: "data:image/png;base64,FIRST".into(),
            detail: Some("high".into()),
        },
        ResponsesContentPart::InputImage {
            image_url: "data:image/png;base64,SECOND".into(),
            detail: None,
        },
    ];
    let user_message = ResponsesInputItem::Message {
        role: "user".into(),
        content: ResponsesMessageContent::Parts(content_parts),
    };
    let req = ResponsesRequest {
        model: "m".into(),
        input: ResponsesInput::Items(vec![user_message]),
        instructions: None,
        stream: false,
        max_output_tokens: None,
        temperature: None,
        top_p: None,
        previous_response_id: None,
        extra: Value::Object(Default::default()),
    };

    let chat = request_to_chat(req).unwrap();
    let parts = match &chat.messages[0].content {
        MessageContent::Parts(translated) => translated,
        other => panic!("expected Parts, got {other:?}"),
    };

    // One text part followed by the two images, in input order.
    assert_eq!(parts.len(), 3);
    assert_eq!(parts[0]["type"], "text");
    assert_eq!(parts[1]["image_url"]["url"], "data:image/png;base64,FIRST");
    assert_eq!(parts[2]["image_url"]["url"], "data:image/png;base64,SECOND");
    // `detail` must not appear inside the chat image_url object.
    assert!(parts[1]["image_url"].get("detail").is_none());
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn text_only_parts_collapse_to_string() {
|
fn text_only_parts_collapse_to_string() {
|
||||||
let req = ResponsesRequest {
|
let req = ResponsesRequest {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ use axum::http::StatusCode;
|
|||||||
use axum::response::{IntoResponse, Json};
|
use axum::response::{IntoResponse, Json};
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use cortex_core::harness::ModelSpec;
|
use cortex_core::harness::ModelSpec;
|
||||||
|
use cortex_core::source::ModelSourceId;
|
||||||
use neuron::harness::preflight::{PreflightError, SourceFormat, preflight};
|
use neuron::harness::preflight::{PreflightError, SourceFormat, preflight};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -89,6 +90,15 @@ fn spec(model_id: &str, tp: Option<u32>, quant: Option<&str>) -> ModelSpec {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a `ModelSourceId` from a bare `org/name` test input,
|
||||||
|
/// substituting the default scheme so the mock route key matches.
|
||||||
|
fn sid(model_id: &str) -> ModelSourceId {
|
||||||
|
model_id
|
||||||
|
.parse::<ModelSourceId>()
|
||||||
|
.expect("test model_id parses")
|
||||||
|
.with_default_scheme("huggingface")
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn preflight_gguf_tp_rejected_over_http() {
|
async fn preflight_gguf_tp_rejected_over_http() {
|
||||||
let cache = tempfile::tempdir().expect("tempdir");
|
let cache = tempfile::tempdir().expect("tempdir");
|
||||||
@@ -107,7 +117,7 @@ async fn preflight_gguf_tp_rejected_over_http() {
|
|||||||
|
|
||||||
let api = build_api(&endpoint, cache.path());
|
let api = build_api(&endpoint, cache.path());
|
||||||
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
|
let s = spec("HauhauCS/Qwen3.6", Some(2), Some("q6k"));
|
||||||
let err = preflight(&api, &s).await.unwrap_err();
|
let err = preflight(&api, &sid(&s.model_id), &s).await.unwrap_err();
|
||||||
match err {
|
match err {
|
||||||
PreflightError::TpRequiresSafetensors {
|
PreflightError::TpRequiresSafetensors {
|
||||||
model_id,
|
model_id,
|
||||||
@@ -115,7 +125,9 @@ async fn preflight_gguf_tp_rejected_over_http() {
|
|||||||
gguf_quants,
|
gguf_quants,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
assert_eq!(model_id, "HauhauCS/Qwen3.6");
|
// Scheme prefix surfaces in error display now that
|
||||||
|
// preflight is source-aware.
|
||||||
|
assert_eq!(model_id, "huggingface:HauhauCS/Qwen3.6");
|
||||||
assert_eq!(tp_size, 2);
|
assert_eq!(tp_size, 2);
|
||||||
assert_eq!(gguf_quants.len(), 3);
|
assert_eq!(gguf_quants.len(), 3);
|
||||||
}
|
}
|
||||||
@@ -140,7 +152,7 @@ async fn preflight_gguf_quant_suggestion_over_http() {
|
|||||||
|
|
||||||
let api = build_api(&endpoint, cache.path());
|
let api = build_api(&endpoint, cache.path());
|
||||||
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
|
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6k"));
|
||||||
let err = preflight(&api, &s).await.unwrap_err();
|
let err = preflight(&api, &sid(&s.model_id), &s).await.unwrap_err();
|
||||||
match err {
|
match err {
|
||||||
PreflightError::QuantNotFound {
|
PreflightError::QuantNotFound {
|
||||||
requested,
|
requested,
|
||||||
@@ -176,7 +188,9 @@ async fn preflight_dense_safetensors_tp_ok() {
|
|||||||
|
|
||||||
let api = build_api(&endpoint, cache.path());
|
let api = build_api(&endpoint, cache.path());
|
||||||
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
|
let s = spec("Qwen/Q3-30B", Some(2), Some("q5k"));
|
||||||
let plan = preflight(&api, &s).await.expect("dense+tp should succeed");
|
let plan = preflight(&api, &sid(&s.model_id), &s)
|
||||||
|
.await
|
||||||
|
.expect("dense+tp should succeed");
|
||||||
assert_eq!(plan.tp_size, 2);
|
assert_eq!(plan.tp_size, 2);
|
||||||
assert!(plan.picked_quant_file.is_none());
|
assert!(plan.picked_quant_file.is_none());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
@@ -197,7 +211,7 @@ async fn preflight_gguf_single_gpu_good_quant() {
|
|||||||
|
|
||||||
let api = build_api(&endpoint, cache.path());
|
let api = build_api(&endpoint, cache.path());
|
||||||
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6_k_p"));
|
let s = spec("HauhauCS/Qwen3.6", Some(1), Some("q6_k_p"));
|
||||||
let plan = preflight(&api, &s)
|
let plan = preflight(&api, &sid(&s.model_id), &s)
|
||||||
.await
|
.await
|
||||||
.expect("good quant should succeed");
|
.expect("good quant should succeed");
|
||||||
assert_eq!(plan.tp_size, 1);
|
assert_eq!(plan.tp_size, 1);
|
||||||
@@ -219,7 +233,7 @@ async fn preflight_repo_fetch_failed_on_404() {
|
|||||||
|
|
||||||
let api = build_api(&endpoint, cache.path());
|
let api = build_api(&endpoint, cache.path());
|
||||||
let s = spec("DoesNot/Exist", Some(1), None);
|
let s = spec("DoesNot/Exist", Some(1), None);
|
||||||
let err = preflight(&api, &s).await.unwrap_err();
|
let err = preflight(&api, &sid(&s.model_id), &s).await.unwrap_err();
|
||||||
assert!(
|
assert!(
|
||||||
matches!(err, PreflightError::RepoFetchFailed { .. }),
|
matches!(err, PreflightError::RepoFetchFailed { .. }),
|
||||||
"expected RepoFetchFailed, got {err:?}"
|
"expected RepoFetchFailed, got {err:?}"
|
||||||
@@ -238,7 +252,7 @@ async fn preflight_empty_repo_rejected() {
|
|||||||
|
|
||||||
let api = build_api(&endpoint, cache.path());
|
let api = build_api(&endpoint, cache.path());
|
||||||
let s = spec("Empty/Repo", Some(1), None);
|
let s = spec("Empty/Repo", Some(1), None);
|
||||||
let err = preflight(&api, &s).await.unwrap_err();
|
let err = preflight(&api, &sid(&s.model_id), &s).await.unwrap_err();
|
||||||
assert!(
|
assert!(
|
||||||
matches!(err, PreflightError::EmptyRepo { .. }),
|
matches!(err, PreflightError::EmptyRepo { .. }),
|
||||||
"expected EmptyRepo, got {err:?}"
|
"expected EmptyRepo, got {err:?}"
|
||||||
@@ -264,6 +278,8 @@ async fn preflight_mixed_repo_prefers_safetensors() {
|
|||||||
// TP=2 + quant should succeed via the dense path even though a
|
// TP=2 + quant should succeed via the dense path even though a
|
||||||
// GGUF is present — the dense path handles ISQ.
|
// GGUF is present — the dense path handles ISQ.
|
||||||
let s = spec("Mixed/Repo", Some(2), Some("q5k"));
|
let s = spec("Mixed/Repo", Some(2), Some("q5k"));
|
||||||
let plan = preflight(&api, &s).await.expect("mixed should succeed");
|
let plan = preflight(&api, &sid(&s.model_id), &s)
|
||||||
|
.await
|
||||||
|
.expect("mixed should succeed");
|
||||||
assert!(matches!(plan.format, SourceFormat::Mixed { .. }));
|
assert!(matches!(plan.format, SourceFormat::Mixed { .. }));
|
||||||
}
|
}
|
||||||
|
|||||||
176
doc/vision-qwen3_6-spec.md
Normal file
176
doc/vision-qwen3_6-spec.md
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
# Qwen3.6-27B vision specification (Stage A0)
|
||||||
|
|
||||||
|
Sourced from beast's local cache on 2026-06-01:
|
||||||
|
`/archive3/llm-cache/models--Qwen--Qwen3.6-27B/snapshots/6a9e13bd6fc8f0983b9b99948120bc37f49c13e9/`.
|
||||||
|
|
||||||
|
Single source of truth for Stages A–D of the vision plan in
|
||||||
|
`~/.claude/plans/foamy-twirling-catmull.md`. Umbrella issue:
|
||||||
|
[#3](https://git.lair.cafe/helexa/cortex/issues/3).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Top-level shape
|
||||||
|
|
||||||
|
The model is a unified text+vision architecture (`Qwen3_5ForConditionalGeneration`,
|
||||||
|
`model_type: qwen3_5`) with three weight sections under a single safetensors
|
||||||
|
index. Counts from `model.safetensors.index.json`:
|
||||||
|
|
||||||
|
| Prefix | Tensors | Role |
|
||||||
|
|---|---|---|
|
||||||
|
| `model.language_model.*` | 850 | LM (currently loaded) |
|
||||||
|
| `model.visual.*` | 333 | Vision tower (currently filtered out at `arch/qwen3_5/mod.rs:228-230`) |
|
||||||
|
| `mtp.*` | 15 | Multi-token-prediction heads (filtered, out of scope) |
|
||||||
|
| `lm_head.weight` | 1 | LM head |
|
||||||
|
|
||||||
|
Vision tensors live in shards `model-00007-of-00015.safetensors` and
|
||||||
|
`model-00008-of-00015.safetensors` (2 of the 15 safetensors shards). Loading just
|
||||||
|
these two for vision-tower-only smoke tests is feasible.
|
||||||
|
|
||||||
|
## Vision tower architecture (`model.visual.*`)
|
||||||
|
|
||||||
|
From `config.json::vision_config`:
|
||||||
|
|
||||||
|
```
|
||||||
|
depth: 27 (transformer blocks)
|
||||||
|
hidden_size: 1152 (vision token dim)
|
||||||
|
num_heads: 16 (per-block self-attention)
|
||||||
|
intermediate_size: 4304 (MLP hidden)
|
||||||
|
patch_size: 16 (16×16 spatial patches)
|
||||||
|
temporal_patch_size: 2 (video frame pairing; irrelevant for stills)
|
||||||
|
spatial_merge_size: 2 (2×2 spatial merge in the merger → 4 patches/LM token)
|
||||||
|
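num_position_embeddings: 2304 (learned pos embed slots, = 48×48; NOTE(review): the 64×64 = 4096-patch example below exceeds 2304, so this is presumably a base grid that gets interpolated rather than a hard cap on patch sequence length — confirm against the reference implementation)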
|
||||||
|
in_channels: 3 (RGB)
|
||||||
|
hidden_act: gelu_pytorch_tanh (GELU with tanh approximation, not exact GELU)
|
||||||
|
out_hidden_size: 5120 (= LM hidden_size, merger output dim)
|
||||||
|
deepstack_visual_indexes: [] (no deep-stack visual indexes)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Module inventory (per-block and global)
|
||||||
|
|
||||||
|
Global:
|
||||||
|
- `model.visual.patch_embed.proj.{weight, bias}` — Conv2d (3 → 1152, kernel 16×16, stride 16). Turns image patches into tokens.
|
||||||
|
- `model.visual.pos_embed.weight` — Learned position embedding, shape `(2304, 1152)`.
|
||||||
|
- `model.visual.merger.{norm, linear_fc1, linear_fc2}` — The projector that merges 2×2 patches and projects to LM hidden_size (1152 → 5120). All weights have biases.
|
||||||
|
|
||||||
|
Per block (×27, named `model.visual.blocks.{0..26}`):
|
||||||
|
- `norm1.{weight, bias}` — **LayerNorm** before attention (with bias — not RmsNorm).
|
||||||
|
- `attn.qkv.{weight, bias}` — Fused QKV linear (1152 → 3·1152 = 3456).
|
||||||
|
- `attn.proj.{weight, bias}` — Attention output projection (1152 → 1152).
|
||||||
|
- `norm2.{weight, bias}` — LayerNorm before MLP.
|
||||||
|
- `mlp.linear_fc1.{weight, bias}` — MLP up-projection (1152 → 4304).
|
||||||
|
- `mlp.linear_fc2.{weight, bias}` — MLP down-projection (4304 → 1152).
|
||||||
|
|
||||||
|
Pattern matches a standard ViT block with **pre-norm** layout (norm → attn → residual, norm → MLP → residual). Activation between fc1/fc2 is GELU-tanh-approx per `hidden_act`. No attention masking inside the vision tower (all patches attend to each other).
|
||||||
|
|
||||||
|
### Forward signature (target)
|
||||||
|
|
||||||
|
```
|
||||||
|
VisionTower::forward(
|
||||||
|
patches: Tensor [N, in_channels, patch_size, patch_size], # CPU-preprocessed RGB float patches
|
||||||
|
grid_thw: Option<(usize, usize, usize)>, # (t, h, w) patch grid for position lookup
|
||||||
|
) -> Tensor [N / (spatial_merge_size²), out_hidden_size] # = (N/4, 5120) for static images
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: the merger consumes 4 spatially-adjacent patches and emits 1 LM token. So an image producing 64×64 = 4096 patches yields 1024 LM-side image tokens.
|
||||||
|
|
||||||
|
## Image preprocessor (`preprocessor_config.json`)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"size": { "longest_edge": 16777216, "shortest_edge": 65536 },
|
||||||
|
"patch_size": 16,
|
||||||
|
"temporal_patch_size": 2,
|
||||||
|
"merge_size": 2,
|
||||||
|
"image_mean": [0.5, 0.5, 0.5],
|
||||||
|
"image_std": [0.5, 0.5, 0.5],
|
||||||
|
"processor_class": "Qwen3VLProcessor",
|
||||||
|
"image_processor_type": "Qwen2VLImageProcessorFast"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Reading:
|
||||||
|
|
||||||
|
- `image_mean = image_std = 0.5` → normalisation is simply `(x/255 - 0.5) / 0.5 = 2*x/255 - 1`, mapping `[0,255]` → `[-1, 1]`. No imagenet-style mean/std.
|
||||||
|
- `size.{shortest_edge, longest_edge}` are **pixel counts**, not edge lengths. The `Qwen2VLImageProcessorFast` recipe picks a resolution within `[65,536 = 256², 16,777,216 = 4096²]` total pixels, snapping `h` and `w` to multiples of `patch_size × spatial_merge_size = 32` pixels.
|
||||||
|
- Stage A ships **fixed resolution**: pick a target pixel count (e.g. 448×448 = 200,704 px → 28×28 patches → 14×14 LM tokens after merger). Variable resolution deferred to issue [#14](https://git.lair.cafe/helexa/cortex/issues/14).
|
||||||
|
|
||||||
|
## Chat template (`chat_template.jinja`)
|
||||||
|
|
||||||
|
Image insertion (lines 8–18 of the template):
|
||||||
|
|
||||||
|
```jinja
|
||||||
|
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
|
||||||
|
...
|
||||||
|
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
|
||||||
|
```
|
||||||
|
|
||||||
|
Per image, the template emits **one `<|image_pad|>` token** flanked by `<|vision_start|>` and `<|vision_end|>` sentinels. The runtime must:
|
||||||
|
|
||||||
|
1. Render the template (preserving the single `<|image_pad|>` per image).
|
||||||
|
2. For each image, replace its single `<|image_pad|>` with N copies, where N is the number of LM tokens that image produces after the vision tower + merger (= `patches / spatial_merge_size²`).
|
||||||
|
3. Tokenize the expanded string → `input_ids`.
|
||||||
|
4. At forward time, locate positions where `input_ids == image_token_id` (248056) and splice in the vision tower's merger output.
|
||||||
|
|
||||||
|
Token IDs (top of `config.json`):
|
||||||
|
- `vision_start_token_id`: 248053
|
||||||
|
- `vision_end_token_id`: 248054
|
||||||
|
- `image_token_id`: 248056
|
||||||
|
- `video_token_id`: 248057 (out of scope)
|
||||||
|
- `bos_token_id`: 248044
|
||||||
|
- `eos_token_id`: 248044, 248046 (per `generation_config.json`)
|
||||||
|
|
||||||
|
System messages cannot contain images (template raises). Other template-side details:
|
||||||
|
- `add_vision_id` (jinja arg, default false): emits `'Picture N: '` prefixes when true.
|
||||||
|
- `preserve_thinking` (jinja arg, default false): keeps `<think>` blocks from prior assistant turns in the rendered prompt.
|
||||||
|
- `enable_thinking` (jinja arg, default true): emits `<think>\n` (or skips it) at the end of the generation prompt.
|
||||||
|
|
||||||
|
The existing chat-template renderer in `crates/neuron/src/harness/chat_template.rs` already passes `MessageContent::Parts` to the Jinja context as a `Value::Array`; the template's `is iterable` branch (line 6 of the template) handles them. **The path is structurally in place** — Stage B just needs to do the `<|image_pad|>` expansion + token-position-aware splice.
|
||||||
|
|
||||||
|
## LM-side considerations
|
||||||
|
|
||||||
|
The LM's RoPE config uses **multi-axis RoPE (MRoPE)**:
|
||||||
|
|
||||||
|
```
|
||||||
|
rope_parameters: {
|
||||||
|
mrope_interleaved: true,
|
||||||
|
mrope_section: [11, 11, 10], # temporal + height + width components (text tokens use the same index on all three axes)
|
||||||
|
partial_rotary_factor: 0.25,
|
||||||
|
rope_theta: 10000000,
|
||||||
|
rope_type: "default"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
MRoPE encodes spatial position alongside text position so the LM attention layers can reason about image-token spatial structure. The LM's existing forward path *may or may not* already implement this — the qwen3_5 module's doc-comment notes "numerical correctness vs the reference Python is not yet validated." Verifying MRoPE behaviour in the language model is out of Stage A scope (vision tower only) but will be required in Stage B (LM splice) and is tracked under the numerical-validation issue [#15](https://git.lair.cafe/helexa/cortex/issues/15).
|
||||||
|
|
||||||
|
`max_position_embeddings = 262144` (256 K context), so context-length limits are not a constraint for vision.
|
||||||
|
|
||||||
|
## Iteration target decision
|
||||||
|
|
||||||
|
The vision tower has its own self-contained weight tree and is small (~333 tensors in 2 shards, hidden_size 1152 vs LM's 5120). For Stage A specifically (vision-tower-only smoke), we **don't need a smaller iteration model** — we can:
|
||||||
|
|
||||||
|
- Build the Rust `VisionTower` struct against the spec above.
|
||||||
|
- Run unit tests with random tensor weights matching the exact shapes → assert forward produces correct output shape with finite values.
|
||||||
|
- Optionally: a CUDA-integration test that loads just the 2 vision shards from beast's cache (or on a smaller GPU like quadbrat's Ampere) and runs encode on a real image. Doesn't require loading the 27B LM at all.
|
||||||
|
|
||||||
|
This sidesteps the "develop against a smaller VL model" question for Stage A. Stage B (LM splice → end-to-end chat with vision) is where iteration speed becomes pressing; revisit there. The default scope pick 2a (smaller iteration model) is therefore deferred to Stage B planning — issue [#13](https://git.lair.cafe/helexa/cortex/issues/13) covers deployment validation regardless.
|
||||||
|
|
||||||
|
## Concrete Stage A1+ inputs
|
||||||
|
|
||||||
|
- Add deps to `crates/neuron/Cargo.toml`:
|
||||||
|
- `image = "0.25"`
|
||||||
|
- `base64 = "0.22"`
|
||||||
|
- Stage A2 preprocessor target resolution (fixed): **448×448 → 28×28 patches → 14×14 = 196 image tokens per image**. This balances minimum-patch-count for cheap tests against the model's expected input range.
|
||||||
|
- Stage A3 module structure: one `VisionTower` struct holding `patch_embed: Conv2d`, `pos_embed: Embedding`, `blocks: Vec<VisionBlock>`, `merger: Merger`. `VisionBlock` carries `norm1`, `norm2`, `attn`, `mlp`. Hand-roll using candle primitives.
|
||||||
|
- Stage A4 weight loading: extend `Qwen3_5ForCausalLM::new()` to construct `Some(VisionTower::new(vb.pp("model.visual"), config))` when `vision_config` is present in the parsed config.
|
||||||
|
- Stage A5 worker job: `Job::EncodeImage { handle, patches: Vec<f32>, patch_shape: (usize, usize, usize, usize, usize), reply: oneshot<Result<Vec<f32>>> }`. Patch shape = `(N, C, T, H, W)` where T=1 for static images.
|
||||||
|
|
||||||
|
## What this doc does NOT settle (deferred to issues)
|
||||||
|
|
||||||
|
- Numerical correctness of `VisionTower` output vs Python transformers
|
||||||
|
→ issue [#15](https://git.lair.cafe/helexa/cortex/issues/15).
|
||||||
|
- Variable image resolution
|
||||||
|
→ issue [#14](https://git.lair.cafe/helexa/cortex/issues/14).
|
||||||
|
- TP-vision (multi-rank vision tower)
|
||||||
|
→ issue [#12](https://git.lair.cafe/helexa/cortex/issues/12).
|
||||||
|
- 27B production deployment
|
||||||
|
→ issue [#13](https://git.lair.cafe/helexa/cortex/issues/13).
|
||||||
@@ -7,7 +7,8 @@
|
|||||||
# returns and what the router can cold-load on demand.
|
# returns and what the router can cold-load on demand.
|
||||||
#
|
#
|
||||||
# Field reference:
|
# Field reference:
|
||||||
# id - HuggingFace model id, exact match.
|
# id - Repo id in the source registry (e.g. "Qwen/Qwen3.6-27B").
|
||||||
|
# Exact match.
|
||||||
# harness - which engine handles inference (currently "candle").
|
# harness - which engine handles inference (currently "candle").
|
||||||
# quant - GGUF quantisation tag for the file in the HF repo
|
# quant - GGUF quantisation tag for the file in the HF repo
|
||||||
# (e.g. "Q4_K_M"). Omit/empty for the dense
|
# (e.g. "Q4_K_M"). Omit/empty for the dense
|
||||||
@@ -20,6 +21,11 @@
|
|||||||
# pinned_on - optional whitelist of neuron names. Non-empty
|
# pinned_on - optional whitelist of neuron names. Non-empty
|
||||||
# narrows feasibility to just those neurons and
|
# narrows feasibility to just those neurons and
|
||||||
# protects the model from LRU eviction there.
|
# protects the model from LRU eviction there.
|
||||||
|
# source - optional source scheme ("huggingface", "helexa",
|
||||||
|
# operator mirror tag). When set, cortex forwards
|
||||||
|
# the load to neuron as `scheme:id` so the daemon
|
||||||
|
# fetches from the right registry. Omit to let
|
||||||
|
# neuron substitute its own `default_source`.
|
||||||
|
|
||||||
# Tensor-parallel target — needs a neuron with at least 2 large GPUs.
|
# Tensor-parallel target — needs a neuron with at least 2 large GPUs.
|
||||||
# The example pins to a specific neuron name; adjust or remove the
|
# The example pins to a specific neuron name; adjust or remove the
|
||||||
@@ -49,6 +55,20 @@ vram_mb = 500
|
|||||||
min_devices = 1
|
min_devices = 1
|
||||||
min_device_vram_mb = 4000
|
min_device_vram_mb = 4000
|
||||||
|
|
||||||
|
# Helexa registry model — `source` pins this entry to the helexa
|
||||||
|
# scheme so cortex forwards `helexa:Helexa/Qwen3.6-27B-Uncensored` to
|
||||||
|
# neuron's /models/load. Requires the neuron config to declare a
|
||||||
|
# matching [harness.candle.sources.helexa] entry pointing at the
|
||||||
|
# helexa registry endpoint (see neuron.example.toml).
|
||||||
|
#
|
||||||
|
# [[models]]
|
||||||
|
# id = "Helexa/Qwen3.6-27B-Uncensored"
|
||||||
|
# harness = "candle"
|
||||||
|
# source = "helexa"
|
||||||
|
# vram_mb = 54000
|
||||||
|
# min_devices = 2
|
||||||
|
# min_device_vram_mb = 24000
|
||||||
|
|
||||||
# -- Tier aliases ------------------------------------------------------------
|
# -- Tier aliases ------------------------------------------------------------
|
||||||
# Optional. Clients can request inference against an alias (e.g.
|
# Optional. Clients can request inference against an alias (e.g.
|
||||||
# `model: "helexa/small"` in /v1/chat/completions) and cortex
|
# `model: "helexa/small"` in /v1/chat/completions) and cortex
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ name = "candle"
|
|||||||
# HuggingFace cache directory for model weights.
|
# HuggingFace cache directory for model weights.
|
||||||
#
|
#
|
||||||
# Resolution order (first hit wins):
|
# Resolution order (first hit wins):
|
||||||
# 1. `hf_cache` here in this file.
|
# 1. `hf_cache` here in this file (applies to the synth `huggingface`
|
||||||
|
# source only — see [harness.candle.sources.*] below for explicit
|
||||||
|
# per-source paths).
|
||||||
# 2. `HF_HUB_CACHE` env var — same convention as the Python
|
# 2. `HF_HUB_CACHE` env var — same convention as the Python
|
||||||
# `huggingface_hub` library, so an existing cache directory shared
|
# `huggingface_hub` library, so an existing cache directory shared
|
||||||
# with other tooling can be reused without per-tool config.
|
# with other tooling can be reused without per-tool config.
|
||||||
@@ -36,6 +38,32 @@ name = "candle"
|
|||||||
# Environment=HF_HUB_CACHE=/archive/hf-cache
|
# Environment=HF_HUB_CACHE=/archive/hf-cache
|
||||||
# hf_cache = "/var/lib/neuron/hf-cache"
|
# hf_cache = "/var/lib/neuron/hf-cache"
|
||||||
|
|
||||||
|
# Default scheme applied to bare `org/name` model ids (those without a
|
||||||
|
# `scheme:` prefix). Defaults to "huggingface" when unset. Set to
|
||||||
|
# "helexa" to make `default_models = [{ model_id = "Helexa/Foo" }]`
|
||||||
|
# resolve via the helexa registry without prefixing every entry.
|
||||||
|
# default_source = "huggingface"
|
||||||
|
|
||||||
|
# Per-scheme source endpoints. Each scheme maps to an HF-compatible
|
||||||
|
# registry. The `huggingface` source is auto-synthesised pointing at
|
||||||
|
# `https://huggingface.co` when omitted; declare it explicitly here to
|
||||||
|
# override the endpoint, auth env, or cache dir.
|
||||||
|
#
|
||||||
|
# [harness.candle.sources.huggingface]
|
||||||
|
# endpoint = "https://huggingface.co"
|
||||||
|
# auth_env = "HF_TOKEN" # optional bearer token via env var
|
||||||
|
# cache_dir = "/archive3/llm-cache/huggingface"
|
||||||
|
#
|
||||||
|
# Add helexa (or any operator-run mirror speaking the HF-compatible
|
||||||
|
# wire format) by adding another sources entry. Caches are
|
||||||
|
# disambiguated per scheme so a mirror serving the same `org/name` as
|
||||||
|
# HF cannot collide on disk.
|
||||||
|
#
|
||||||
|
# [harness.candle.sources.helexa]
|
||||||
|
# endpoint = "https://registry.helexa.ai"
|
||||||
|
# auth_env = "HELEXA_TOKEN"
|
||||||
|
# cache_dir = "/archive3/llm-cache/helexa"
|
||||||
|
|
||||||
# -- Default models ----------------------------------------------------------
|
# -- Default models ----------------------------------------------------------
|
||||||
# Models listed here are loaded automatically when the neuron service
|
# Models listed here are loaded automatically when the neuron service
|
||||||
# activates. Loading is sequential — a slow or failing entry doesn't
|
# activates. Loading is sequential — a slow or failing entry doesn't
|
||||||
|
|||||||
303
script/deploy.sh
303
script/deploy.sh
@@ -1,303 +0,0 @@
|
|||||||
#!/usr/bin/env bash
#
# Rolling deploy across the helexa fleet, driven by asset/manifest.yml.
# Installs / upgrades cortex on the gateway host and the appropriate
# helexa-neuron-<flavour> package on each neuron host, then restarts
# their services.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
MANIFEST="${REPO_DIR}/asset/manifest.yml"

if [[ ! -f "${MANIFEST}" ]]; then
    echo "fatal: manifest not found at ${MANIFEST}" >&2
    exit 1
fi

# Parse the manifest with yq. NOTE: this expects the pip-installed yq
# (a jq wrapper using jq syntax) — `pip install yq`. The Fedora rpm
# `yq` is mikefarah/yq and uses different (yaml-native) syntax; if a
# host has that one instead these queries will fail.
cortex_host=$(yq -r '.cortex.host' "${MANIFEST}")

# With jq syntax a missing key comes back as the literal string "null".
# Refuse to continue rather than ssh to a host called "null".
if [[ -z "${cortex_host}" || "${cortex_host}" == "null" ]]; then
    echo "fatal: .cortex.host is not set in ${MANIFEST}" >&2
    exit 1
fi

# Collect one TAB-separated 'host\tflavour' line per neuron. The query
# runs as a plain command substitution (not a process substitution) so
# that a yq failure aborts the script under `set -e` instead of being
# silently turned into an empty neuron list.
neuron_lines=$(yq -r '.neurons[] | .host + "\t" + .flavour' "${MANIFEST}")
neuron_entries=()
if [[ -n "${neuron_lines}" ]]; then
    mapfile -t neuron_entries <<< "${neuron_lines}"
fi
|
|
||||||
|
|
||||||
# Print the "version-release" string of a package on a remote host, or
# the literal "(not installed)" when the remote rpm query fails.
#
# Arguments:
#   $1  host  ssh destination
#   $2  pkg   package name to query
#
# rpm's output is captured rather than streamed because rpm reports
# "package X is not installed" on stdout (not stderr) when -q fails;
# capturing it lets us discard that text on the failure path.
installed_nvr() {
    local host="$1" pkg="$2"
    local query="rpm -q --qf '%{version}-%{release}' ${pkg} 2>/dev/null"
    local result
    result=$(ssh "${host}" "${query}") || result="(not installed)"
    echo "${result}"
}
|
|
||||||
|
|
||||||
# Make sure the rpm.lair.cafe unstable repo is both present and enabled
# on a remote host.
#
# Arguments:
#   $1  host  ssh destination
#
# The published .repo file ships with `enabled=0`, so fetching it is not
# enough: the repo has to be switched on with
# `dnf config-manager setopt`. Both steps are safe to repeat.
#
# Never fatal. A failure in either step only prints a warning; the
# later `dnf install` is left to report the real problem.
ensure_lair_repo() {
    local host="$1"
    local repo_file="/etc/yum.repos.d/lair-cafe-unstable.repo"
    local repo_url="https://rpm.lair.cafe/lair-cafe-unstable.repo"

    if ssh "${host}" "test -f ${repo_file}" 2>/dev/null; then
        : # repo file already in place; only the enable step remains
    else
        echo "[${host}] adding rpm.lair.cafe unstable repo"
        ssh "${host}" sudo dnf config-manager addrepo \
            "--from-repofile=${repo_url}" >/dev/null 2>&1 || {
            echo "[${host}] WARNING: failed to add lair.cafe repo file (proceeding anyway)"
            return 0
        }
    fi

    # Flip enabled=1; cheap and repeatable.
    ssh "${host}" sudo dnf config-manager setopt \
        lair-cafe-unstable.enabled=1 >/dev/null 2>&1 \
        || echo "[${host}] WARNING: failed to enable lair-cafe-unstable (proceeding anyway)"
    return 0
}
|
|
||||||
|
|
||||||
# Make libcudnn.so.9 resolvable on a remote host so the neuron binary
# (built with --features cudnn) does not die at startup with
# "cannot open shared object file: No such file or directory".
#
# Arguments:
#   $1  host  ssh destination
#
# ldconfig is consulted first: a manually installed cuDNN (.tar/.run)
# shows up in its cache and is left alone. Otherwise NVIDIA's RHEL9 CUDA
# repo is added (the Fedora 43 CUDA repo carries no cuDNN packages, only
# the RHEL9 one does) and libcudnn9-cuda-13 is installed from it.
#
# Never fatal: every failure is downgraded to a warning.
ensure_cudnn_runtime() {
    local host="$1"
    local cuda_repo_file="/etc/yum.repos.d/cuda-rhel9-x86_64.repo"
    local cuda_repo_url="https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo"

    # Already resolvable: nothing to do.
    ssh "${host}" "ldconfig -p | grep -q libcudnn.so.9" 2>/dev/null && return 0

    echo "[${host}] installing cuDNN runtime"

    if ! ssh "${host}" "test -f ${cuda_repo_file}" 2>/dev/null; then
        ssh "${host}" sudo dnf config-manager addrepo \
            "--from-repofile=${cuda_repo_url}" >/dev/null 2>&1 \
            || echo "[${host}] WARNING: failed to add rhel9 CUDA repo (proceeding anyway)"
    fi

    if ssh "${host}" sudo dnf install -y libcudnn9-cuda-13 >/dev/null 2>&1; then
        return 0
    fi
    echo "[${host}] WARNING: failed to install libcudnn9-cuda-13"
    echo "[${host}] neuron may fail to start; install cuDNN manually if so"
}
|
|
||||||
|
|
||||||
# Succeed (return 0) when a package on a remote host has to be installed
# or upgraded; fail (return 1) only when the installed copy is current.
#
# Arguments:
#   $1  host  ssh destination
#   $2  pkg   package name
#
# `dnf check-update <pkg>` exits 0 for a package that is not installed
# at all (nothing to update), so absence is probed separately with
# `rpm -q` first. Any non-zero result from the dnf probe, including a
# genuine dnf failure, is treated as "needs update" so the install step
# that follows reports the real error instead of this check hiding it.
needs_update() {
    local host="$1" pkg="$2"

    # Absent package: definitely needs work.
    ssh "${host}" "rpm -q ${pkg}" >/dev/null 2>&1 || return 0

    # Present: "up to date" only when dnf's check exits cleanly.
    ! ssh "${host}" sudo dnf check-update --refresh -q "${pkg}" >/dev/null 2>&1
}
|
|
||||||
|
|
||||||
# Succeed when the package is present on the remote host (any version).
#
# Arguments:
#   $1  host  ssh destination
#   $2  pkg   package name
#
# Callers use this to choose between `dnf install` and `dnf upgrade`:
# dnf5's `install` does nothing for a package that is already present,
# at whatever version, so using it on a stale host would silently leave
# the old build in place.
is_installed() {
    local host="$1" pkg="$2"
    local probe="rpm -q ${pkg}"
    ssh "${host}" "${probe}" >/dev/null 2>&1
}
|
|
||||||
|
|
||||||
# Bring a package on a remote host to the latest available build, using
# `dnf upgrade` when it is already installed and `dnf install` when it
# is not.
#
# Arguments:
#   $1  host  ssh destination
#   $2  pkg   package name
#
# Returns 0 on success and 1 on failure. Either way dnf's combined
# stdout/stderr is left in the global __DNF_OUTPUT__ for the caller to
# print.
__DNF_OUTPUT__=""
install_or_upgrade() {
    local host="$1" pkg="$2"
    local verb="install"

    if is_installed "${host}" "${pkg}"; then
        verb="upgrade"
    fi

    __DNF_OUTPUT__=$(
        ssh "${host}" sudo dnf "${verb}" --refresh --allowerasing -y "${pkg}" 2>&1
    ) || return 1
    return 0
}
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
# cortex (gateway)
# ---------------------------------------------------------------------------
#
# Sequence: make the repo available, upgrade the package if the repo has
# something newer, push the operator-owned config files, then make sure
# the service ends up running.

ensure_lair_repo "${cortex_host}"
cortex_nvr=$(installed_nvr "${cortex_host}" cortex)

# Only stop via systemctl when the unit file exists. On a fresh host it
# does not, and `systemctl stop` on a missing unit exits non-zero, which
# would wrongly send us down the "failed to stop" path.
cortex_stop_cmd="[ ! -f /usr/lib/systemd/system/cortex.service ] || sudo systemctl stop cortex.service"

if ! needs_update "${cortex_host}" cortex; then
    echo "[${cortex_host}] cortex is up to date (${cortex_nvr})"
    ssh "${cortex_host}" sudo systemctl stop cortex.service || true
else
    echo "[${cortex_host}] cortex update available (current: ${cortex_nvr})"
    if ! ssh "${cortex_host}" "${cortex_stop_cmd}"; then
        echo "[${cortex_host}] failed to stop cortex service"
    else
        echo "[${cortex_host}] stopped cortex service"
        if ! install_or_upgrade "${cortex_host}" cortex; then
            echo "[${cortex_host}] failed to install/upgrade cortex:"
            echo "${__DNF_OUTPUT__}" | sed "s/^/[${cortex_host}] /"
        else
            cortex_nvr=$(installed_nvr "${cortex_host}" cortex)
            echo "[${cortex_host}] installed/upgraded cortex to ${cortex_nvr}"
        fi
    fi
fi

# Options shared by both gateway config pushes: the file is written
# through `sudo rsync` on the remote side and lands as root:root 644.
cortex_rsync_opts=(
    --archive
    --compress
    --rsync-path 'sudo rsync'
    --chown root:root
    --chmod 644
)

# cortex.toml is pushed on every run, upgraded or not, because the
# config can change without a package bump.
if ! rsync "${cortex_rsync_opts[@]}" \
    "${REPO_DIR}/cortex.toml" \
    "${cortex_host}:/etc/cortex/cortex.toml"; then
    echo "[${cortex_host}] failed to sync cortex.toml"
else
    echo "[${cortex_host}] sync'd cortex.toml"
fi

# models.toml follows the same lifecycle as cortex.toml — operator-owned,
# gitignored, drives /v1/models catalogue × topology resolution — but is
# optional: without a local copy the remote file is left alone.
if [[ ! -f "${REPO_DIR}/models.toml" ]]; then
    echo "[${cortex_host}] no local models.toml — leaving /etc/cortex/models.toml untouched"
elif rsync "${cortex_rsync_opts[@]}" \
    "${REPO_DIR}/models.toml" \
    "${cortex_host}:/etc/cortex/models.toml"; then
    echo "[${cortex_host}] sync'd models.toml"
else
    echo "[${cortex_host}] failed to sync models.toml"
fi

# Reload unit definitions, then start the service unless it is already
# running. A daemon-reload failure aborts the script (set -e).
ssh "${cortex_host}" sudo systemctl daemon-reload
if ssh "${cortex_host}" systemctl is-active --quiet cortex.service; then
    echo "[${cortex_host}] cortex service is active"
elif ssh "${cortex_host}" sudo systemctl start cortex.service; then
    echo "[${cortex_host}] started cortex service"
else
    echo "[${cortex_host}] failed to start cortex service"
fi
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
# neuron (per-host, flavour from manifest)
# ---------------------------------------------------------------------------
#
# For every 'host\tflavour' entry: prepare repos and the cuDNN runtime,
# stop the service, install/upgrade helexa-neuron-<flavour> if needed,
# push the per-host config, and start the service again. Failures on one
# host are reported and the loop moves on to the next host.

for entry in "${neuron_entries[@]}"; do
    IFS=$'\t' read -r neuron_host neuron_flavour <<< "${entry}"
    package="helexa-neuron-${neuron_flavour}"

    # First dot-component of the host keys the per-host config file
    # under asset/neuron/<short>.toml. A host listed in the manifest
    # without a corresponding config still deploys (the package's
    # default /etc/neuron/neuron.toml stays in place; no pre-warm).
    short_host="${neuron_host%%.*}"
    host_config="${REPO_DIR}/asset/neuron/${short_host}.toml"

    ensure_lair_repo "${neuron_host}"
    ensure_cudnn_runtime "${neuron_host}"
    neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")

    # Stop the service unconditionally before any reconfig step.
    # `default_models` is read at activation, so a config change without
    # a bounce silently leaves the host on the previous pre-warm set.
    # The `[ ! -f … ]` guard skips the stop on a fresh install where the
    # unit file isn't there yet.
    if ssh "${neuron_host}" "[ ! -f /usr/lib/systemd/system/neuron.service ] || sudo systemctl stop neuron.service"; then
        echo "[${neuron_host}] stopped neuron service"
    else
        echo "[${neuron_host}] failed to stop neuron service (continuing)"
    fi

    if needs_update "${neuron_host}" "${package}"; then
        echo "[${neuron_host}] ${package} update available (current: ${neuron_nvr})"
        # --allowerasing lets dnf swap out a previously-installed
        # bare helexa-neuron or a different flavour without manual
        # intervention. The Conflicts: clauses in the spec ensure
        # only one flavour is ever resident.
        if install_or_upgrade "${neuron_host}" "${package}"; then
            neuron_nvr=$(installed_nvr "${neuron_host}" "${package}")
            echo "[${neuron_host}] installed/upgraded ${package} to ${neuron_nvr}"
            # Ensure firewalld allows the neuron port. The add + reload
            # pair is grouped with { …; } because `||` and `&&` have
            # equal precedence in the shell: ungrouped, the command
            # parses as `(query || add) && reload` and reloads firewalld
            # even when the service is already allowed. Best-effort:
            # errors are ignored.
            ssh "${neuron_host}" "sudo firewall-cmd --query-service=helexa-neuron --quiet 2>/dev/null || { sudo firewall-cmd --add-service=helexa-neuron --permanent && sudo firewall-cmd --reload; }" 2>/dev/null || true
        else
            echo "[${neuron_host}] failed to install ${package}:"
            echo "${__DNF_OUTPUT__}" | sed "s/^/[${neuron_host}] /"
        fi
    else
        echo "[${neuron_host}] ${package} is up to date (${neuron_nvr})"
    fi

    # Sync per-host neuron.toml — drives default_models pre-warm so
    # `/v1/models` on the gateway exposes the host's headline model
    # immediately after the service comes back up. Missing per-host
    # config leaves the package's installed neuron.toml untouched.
    if [[ -f "${host_config}" ]]; then
        if rsync \
            --archive \
            --compress \
            --rsync-path 'sudo rsync' \
            --chown root:root \
            --chmod 644 \
            "${host_config}" \
            "${neuron_host}:/etc/neuron/neuron.toml"; then
            echo "[${neuron_host}] sync'd asset/neuron/${short_host}.toml"
        else
            echo "[${neuron_host}] failed to sync neuron.toml"
        fi
    else
        echo "[${neuron_host}] no asset/neuron/${short_host}.toml — leaving /etc/neuron/neuron.toml untouched"
    fi

    # Pick up any new unit file, then bring the service back.
    if ssh "${neuron_host}" "sudo systemctl daemon-reload && sudo systemctl start neuron.service"; then
        echo "[${neuron_host}] started neuron service"
    else
        echo "[${neuron_host}] failed to start neuron service"
    fi
done
|
|
||||||
151
script/infra-setup.sh
Executable file
151
script/infra-setup.sh
Executable file
@@ -0,0 +1,151 @@
|
|||||||
|
#!/usr/bin/env bash
#
# One-time preparation of every host targeted by the
# .gitea/workflows/deploy.yml workflow. For each host this script:
#   - creates the gitea_ci system user when it does not exist yet
#   - installs the runner's public key as ~gitea_ci/.ssh/authorized_keys
#   - installs the matching /etc/sudoers.d/helexa_gitea_ci drop-in
#     (cortex flavour on the gateway, neuron flavour on neuron hosts)
#
# Safe to re-run after fleet changes. An unreachable host is reported
# and skipped so that one offline node does not block the others.

script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
repo_path="$(cd "${script_dir}/.." && pwd)"

# Fleet inventory: one gateway plus the neuron hosts.
cortex_host="hanzalova.internal"
neuron_hosts=(
    "beast.hanzalova.internal"
    "benjy.hanzalova.internal"
    "quadbrat.hanzalova.internal"
)

# Public half of the runner's deploy key; nothing can be provisioned
# without it, so bail out early with a hint on how to create it.
pubkey="${HOME}/.ssh/id_gitea_ci.pub"
[[ -f "${pubkey}" ]] || {
    echo "fatal: ${pubkey} not found" >&2
    echo " generate with: ssh-keygen -t ed25519 -f ${pubkey%.pub} -C gitea_ci" >&2
    exit 1
}
|
||||||
|
|
||||||
|
# Provision gitea_ci on every host (gateway first, then all neurons).
#
# The host list is written as two separately quoted words,
# "${cortex_host}" and "${neuron_hosts[@]}". Putting both inside one
# pair of quotes would glue the gateway and the first neuron into a
# single space-joined word; referenced unquoted in an ssh call, that
# word splits and ssh treats the second hostname as the remote command.

# Script executed on each remote host. It is a single-quoted string, so
# the comments inside it deliberately avoid apostrophes.
provision_script='
    set -eu
    if id -u gitea_ci >/dev/null 2>&1; then
        echo " gitea_ci user already present"
    else
        sudo useradd --system --create-home \
            --home-dir /var/lib/gitea_ci --shell /bin/bash gitea_ci
        echo " gitea_ci user created"
    fi
    # `sudo install` runs as root (not as gitea_ci), which avoids
    # the "sudo: unknown user gitea_ci" failure seen immediately
    # after useradd — NSS caching lags briefly and `sudo -u` cant
    # resolve the just-created user, but `install -o` does its
    # own fresh lookup.
    sudo install -d -o gitea_ci -g gitea_ci -m 0700 \
        /var/lib/gitea_ci/.ssh
    # Grant journal read access so the deploy workflow can capture
    # `journalctl -u <unit> -I` after a service start without
    # needing a sudoers entry. Idempotent — usermod -aG on an
    # already-member is a no-op.
    sudo usermod -aG systemd-journal gitea_ci
'

# The key is written through `sudo rsync` on the remote side and ends up
# owned by gitea_ci with mode 0600.
authorized_keys_opts=(
    --archive
    --compress
    --chown gitea_ci:gitea_ci
    --chmod 0600
    --rsync-path 'sudo rsync'
)

for host in "${cortex_host}" "${neuron_hosts[@]}"; do
    echo "==> ${host}"

    # Without the user and its .ssh directory the key push is pointless.
    ssh "${host}" "${provision_script}" || {
        echo " failed to provision gitea_ci — skipping ${host}"
        continue
    }

    if ! rsync "${authorized_keys_opts[@]}" \
        "${pubkey}" \
        "${host}:/var/lib/gitea_ci/.ssh/authorized_keys"; then
        echo " failed to sync authorized_keys"
    else
        echo " authorized_keys synced"
    fi
done
|
||||||
|
|
||||||
|
# Install /etc/sudoers.d/helexa_gitea_ci on a host, validating the file
# with visudo BEFORE it reaches /etc/sudoers.d, so a typo in the
# template can't end up in the live sudo configuration.
#
# Arguments:
#   $1  host      ssh destination
#   $2  template  local path of the sudoers drop-in to install
#
# Steps:
#   1. create a private staging file on the host with mktemp
#   2. copy the template into it
#   3. run `sudo visudo -cf` on the staging file
#   4. only if that passes, `sudo install` it as root:root 0440
#   5. remove the staging file either way
#
# Every failure is reported and the function returns without touching
# the existing drop-in. Nothing here aborts the script, so the remaining
# hosts are still processed.
install_sudoers() {
    local host="$1" template="$2"
    local target="/etc/sudoers.d/helexa_gitea_ci"
    local staged

    echo "==> ${host}: installing ${target}"

    if ! staged=$(ssh "${host}" 'mktemp /tmp/helexa_gitea_ci.XXXXXX') \
        || [[ -z "${staged}" ]]; then
        echo " failed to create a staging file on ${host}"
        return
    fi

    # Stream the template into the mktemp-created file.
    if ! ssh "${host}" "cat > '${staged}'" < "${template}"; then
        echo " failed to sync ${template##*/}"
        ssh "${host}" "rm -f '${staged}'"
        return
    fi

    # Validate, then install. The staging file is removed on every path
    # and the remote exit status reflects validate+install only.
    if ssh "${host}" "if sudo visudo -cf '${staged}' >/dev/null; then sudo install -o root -g root -m 0440 '${staged}' '${target}'; rc=\$?; else rc=1; fi; rm -f '${staged}'; exit \${rc}"; then
        echo " installed and verified"
    else
        echo " WARNING: ${template##*/} failed validation or install — ${target} left unchanged on ${host}"
    fi
}
|
||||||
|
|
||||||
|
# The gateway gets the cortex drop-in; every neuron host gets the
# neuron one.
sudoers_dir="${repo_path}/asset/sudoers.d"

install_sudoers "${cortex_host}" "${sudoers_dir}/cortex-host.conf"

for neuron_host in "${neuron_hosts[@]}"; do
    install_sudoers "${neuron_host}" "${sudoers_dir}/neuron-host.conf"
done
|
||||||
|
|
||||||
|
# Copy one application config file to a host.
#
# Config pushes live in this script rather than in the deploy workflow
# (which only installs packages and restarts services) because:
#   - cortex.toml and models.toml are gitignored (operator-owned, may
#     include secrets), so CI never sees them
#   - asset/neuron/<short>.toml is tracked, but editing locally and
#     re-running this script is faster than pushing a commit and waiting
#     for build-prerelease to roll over
#
# Arguments:
#   $1  host  ssh destination
#   $2  src   local file to push
#   $3  dst   absolute path on the remote host
#
# A missing local file is not an error: a one-line note is printed and
# the remote file is left alone. The file is written through
# `sudo rsync` and lands as root:root 0644.
sync_config() {
    local host="$1" src="$2" dst="$3"
    local label="${src##*/}"
    local -a rsync_opts=(
        --archive
        --compress
        --chown root:root
        --chmod 0644
        --rsync-path 'sudo rsync'
    )

    [[ -f "${src}" ]] || {
        echo " ${label} not present locally — skipping"
        return
    }

    if ! rsync "${rsync_opts[@]}" "${src}" "${host}:${dst}"; then
        echo " failed to sync ${label} to ${host}"
    else
        echo " ${label} → ${host}:${dst}"
    fi
}
|
||||||
|
|
||||||
|
# Gateway: both operator-owned files live directly in the repo root and
# map to the same file name under /etc/cortex.
echo "==> ${cortex_host}: syncing gateway configs"
for gateway_cfg in cortex.toml models.toml; do
    sync_config "${cortex_host}" \
        "${repo_path}/${gateway_cfg}" \
        "/etc/cortex/${gateway_cfg}"
done

# Neurons: the per-host file is keyed by the first dot-component of the
# hostname and always lands as /etc/neuron/neuron.toml.
for neuron_host in "${neuron_hosts[@]}"; do
    echo "==> ${neuron_host}: syncing per-host neuron config"
    short="${neuron_host%%.*}"
    sync_config "${neuron_host}" \
        "${repo_path}/asset/neuron/${short}.toml" \
        /etc/neuron/neuron.toml
done
|
||||||
Reference in New Issue
Block a user