Compare commits

...

3 Commits

Author SHA1 Message Date
e3ca2ae502 Add OCR text extraction with separate labels table
Florence-2 OCR variant (florence-2-base-ocr) extracts text from images
using the same ONNX model files as florence-2-base with a different
prompt. Results are stored in a new `labels` table, keeping OCR
artifacts separate from descriptive captions. Empty labels mark images
that were scanned but contained no text, preventing re-scanning.

- Add labels table (migration 0014) with Label entity and data layer
- Florence-2 model now accepts configurable prompts via load_with_prompt
- OCR model variant reuses florence-2-base directory (no symlink needed)
- Main loop routes OCR models to labels table, caption models to captions
- Quick search includes label matches; advanced search adds label filter
- UI search page adds "Filter by image text" input
- Deploy script starts florence-2-base-ocr service alongside others

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 09:45:29 +03:00
6dd99e3b0a Use florence-2-base captions as image alt text in gallery viewer
Add GET /api/galleries/{id}/captions endpoint that returns captions
for all images in a gallery. Gallery UI fetches captions on mount and
uses them as alt text on main image and thumbnails, falling back to
filename when no caption is available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 09:17:15 +03:00
c450fa7f89 Add image captioning service with auto-download from HuggingFace
New standalone rbv-caption binary that generates image captions using
ONNX models. Fetches images via CDN, writes captions to a new captions
table, and integrates with search (both quick and advanced modes).

Supported models:
- vit-gpt2: ViT encoder + GPT-2 decoder (auto-download from Xenova)
- florence-2-base: Florence-2 4-stage pipeline using fine-tuned variant
  from onnx-community (auto-download)
- blip-base, git-base: manual ONNX export required

Key implementation details:
- Florence-2 task tokens are natural language prompts, not special tokens
- Uses non-merged decoder ONNX models (no KV cache) for simplicity
- Systemd template unit for deploying multiple models concurrently
- Deploy script targets quadbrat for GPU inference

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 08:45:31 +03:00
27 changed files with 2113 additions and 15 deletions

594
Cargo.lock generated
View File

@@ -430,6 +430,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.44"
@@ -529,6 +535,19 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "console"
version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"windows-sys 0.59.0",
]
[[package]]
name = "const-oid"
version = "0.9.6"
@@ -541,6 +560,16 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
@@ -777,6 +806,27 @@ dependencies = [
"subtle",
]
[[package]]
name = "dirs"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab"
dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.61.2",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
@@ -809,6 +859,21 @@ dependencies = [
"serde",
]
[[package]]
name = "encode_unicode"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "equivalent"
version = "1.0.2"
@@ -967,6 +1032,21 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.32"
@@ -1011,6 +1091,17 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.32"
@@ -1029,8 +1120,10 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@@ -1055,8 +1148,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
@@ -1066,9 +1161,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"r-efi 5.3.0",
"wasip2",
"wasm-bindgen",
]
[[package]]
@@ -1156,12 +1253,42 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hf-hub"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
dependencies = [
"dirs",
"futures",
"http",
"indicatif",
"libc",
"log",
"native-tls",
"num_cpus",
"rand 0.9.2",
"reqwest",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"ureq 2.12.1",
"windows-sys 0.60.2",
]
[[package]]
name = "hkdf"
version = "0.12.4"
@@ -1266,6 +1393,40 @@ dependencies = [
"pin-utils",
"smallvec",
"tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
"webpki-roots 1.0.6",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
@@ -1274,13 +1435,23 @@ version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"hyper",
"ipnet",
"libc",
"percent-encoding",
"pin-project-lite",
"socket2",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@@ -1462,6 +1633,35 @@ dependencies = [
"serde_core",
]
[[package]]
name = "indicatif"
version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
dependencies = [
"console",
"number_prefix",
"portable-atomic",
"unicode-width",
"web-time",
]
[[package]]
name = "ipnet"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2"
[[package]]
name = "iri-string"
version = "0.7.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20"
dependencies = [
"memchr",
"serde",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
@@ -1579,6 +1779,12 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "lzma-rust2"
version = "0.15.7"
@@ -1854,6 +2060,22 @@ dependencies = [
"libm",
]
[[package]]
name = "num_cpus"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "number_prefix"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "oid-registry"
version = "0.7.1"
@@ -1941,6 +2163,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ort"
version = "2.0.0-rc.12"
@@ -1951,7 +2179,7 @@ dependencies = [
"ort-sys",
"smallvec",
"tracing",
"ureq",
"ureq 3.3.0",
]
[[package]]
@@ -1962,7 +2190,7 @@ checksum = "d7b497d21a8b6fbb4b5a544f8fadb77e801a09ae0add9e411d31c6f89e3c1e90"
dependencies = [
"hmac-sha256",
"lzma-rust2",
"ureq",
"ureq 3.3.0",
]
[[package]]
@@ -2172,6 +2400,61 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.18",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
dependencies = [
"bytes",
"getrandom 0.3.4",
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.18",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
]
[[package]]
name = "quote"
version = "1.0.45"
@@ -2330,6 +2613,26 @@ dependencies = [
"x509-parser",
]
[[package]]
name = "rbv-caption"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"hf-hub",
"image",
"ort",
"rbv-data",
"rbv-entity",
"rbv-hash",
"reqwest",
"sqlx",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "rbv-cli"
version = "0.1.0"
@@ -2459,6 +2762,17 @@ dependencies = [
"bitflags",
]
[[package]]
name = "redox_users"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac"
dependencies = [
"getrandom 0.2.17",
"libredox",
"thiserror 2.0.18",
]
[[package]]
name = "regex"
version = "1.12.3"
@@ -2488,6 +2802,53 @@ version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
[[package]]
name = "reqwest"
version = "0.12.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
dependencies = [
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-util",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"webpki-roots 1.0.6",
]
[[package]]
name = "ring"
version = "0.17.14"
@@ -2522,6 +2883,12 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustc-hash"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe"
[[package]]
name = "rusticata-macros"
version = "4.1.0"
@@ -2575,6 +2942,7 @@ version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"web-time",
"zeroize",
]
@@ -2624,7 +2992,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation 0.10.1",
"core-foundation-sys",
"libc",
"security-framework-sys",
@@ -3091,6 +3459,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
@@ -3103,6 +3474,27 @@ dependencies = [
"syn",
]
[[package]]
name = "system-configuration"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
dependencies = [
"bitflags",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tempfile"
version = "3.27.0"
@@ -3296,6 +3688,16 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
@@ -3361,12 +3763,14 @@ dependencies = [
"http-body-util",
"http-range-header",
"httpdate",
"iri-string",
"mime",
"mime_guess",
"percent-encoding",
"pin-project-lite",
"tokio",
"tokio-util",
"tower",
"tower-layer",
"tower-service",
"tracing",
@@ -3459,6 +3863,12 @@ dependencies = [
"tracing-serde",
]
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "typenum"
version = "1.19.0"
@@ -3513,6 +3923,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
@@ -3531,6 +3947,26 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
dependencies = [
"base64 0.22.1",
"flate2",
"log",
"native-tls",
"once_cell",
"rustls",
"rustls-pki-types",
"serde",
"serde_json",
"socks",
"url",
"webpki-roots 0.26.11",
]
[[package]]
name = "ureq"
version = "3.3.0"
@@ -3621,6 +4057,15 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.11.1+wasi-snapshot-preview1"
@@ -3664,6 +4109,20 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8"
dependencies = [
"cfg-if",
"futures-util",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.114"
@@ -3718,6 +4177,19 @@ dependencies = [
"wasmparser",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "wasmparser"
version = "0.244.0"
@@ -3730,6 +4202,26 @@ dependencies = [
"semver",
]
[[package]]
name = "web-sys"
version = "0.3.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-root-certs"
version = "1.0.6"
@@ -3836,6 +4328,17 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-registry"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
dependencies = [
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-result"
version = "0.4.1"
@@ -3872,6 +4375,24 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.60.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb"
dependencies = [
"windows-targets 0.53.5",
]
[[package]]
name = "windows-sys"
version = "0.61.2"
@@ -3905,13 +4426,30 @@ dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm",
"windows_i686_gnullvm 0.52.6",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.53.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3"
dependencies = [
"windows-link",
"windows_aarch64_gnullvm 0.53.1",
"windows_aarch64_msvc 0.53.1",
"windows_i686_gnu 0.53.1",
"windows_i686_gnullvm 0.53.1",
"windows_i686_msvc 0.53.1",
"windows_x86_64_gnu 0.53.1",
"windows_x86_64_gnullvm 0.53.1",
"windows_x86_64_msvc 0.53.1",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
@@ -3924,6 +4462,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
@@ -3936,6 +4480,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_aarch64_msvc"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
@@ -3948,12 +4498,24 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnu"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_gnullvm"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
@@ -3966,6 +4528,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_i686_msvc"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
@@ -3978,6 +4546,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnu"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
@@ -3990,6 +4564,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
@@ -4002,6 +4582,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "windows_x86_64_msvc"
version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
[[package]]
name = "wit-bindgen"
version = "0.51.0"

View File

@@ -12,6 +12,7 @@ members = [
"crates/rbv-search",
"crates/rbv-cli",
"crates/rbv-api",
"crates/rbv-caption",
]
[workspace.package]
@@ -80,3 +81,9 @@ rbv-cluster = { path = "crates/rbv-cluster" }
rbv-ingest = { path = "crates/rbv-ingest" }
rbv-auth = { path = "crates/rbv-auth" }
rbv-search = { path = "crates/rbv-search" }
# HTTP client
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
# HuggingFace model hub
hf-hub = { version = "0.4", features = ["tokio"] }

View File

@@ -0,0 +1,17 @@
[Unit]
Description=rbv caption (%i)
After=network.target
ConditionFileIsExecutable=/usr/local/bin/rbv-caption
[Service]
Environment=RUST_LOG=debug,ort=off,sqlx::query=off,hyper_util=off
ExecStart=/usr/local/bin/rbv-caption \
--model %i \
--model-dir /var/lib/rbv/models \
--database postgres://rbv:password@gramathea.kosherinata.internal:4432/rbv \
--cdn-map /tank/data/rbv/vault=https://rbv.internal/vault \
--batch-size 100
Restart=always
[Install]
WantedBy=multi-user.target

View File

@@ -18,6 +18,7 @@ pub fn router() -> Router<AppState> {
.route("/{id}/persons", get(get_gallery_persons))
.route("/{id}/tags", axum::routing::post(add_tag))
.route("/{id}/subjects", axum::routing::post(add_subject))
.route("/{id}/captions", get(get_gallery_captions))
}
#[derive(Deserialize)]
@@ -277,3 +278,24 @@ async fn add_subject(
.ok_or_else(|| ApiError::not_found("gallery not found"))?;
Ok((StatusCode::OK, Json(gallery.into())))
}
// ── Captions ────────────────────────────────────────────────────────────────
#[derive(Deserialize)]
pub struct CaptionQuery {
#[serde(default = "default_caption_model")]
pub model: String,
}
fn default_caption_model() -> String { "florence-2-base".into() }
async fn get_gallery_captions(
State(state): State<AppState>,
Path(id): Path<String>,
Query(q): Query<CaptionQuery>,
) -> ApiResult<Json<std::collections::HashMap<String, String>>> {
let bytes = rbv_hash::from_hex(&id)
.map_err(|_| ApiError::bad_request("invalid gallery id"))?;
let gid = rbv_entity::GalleryId(bytes);
let captions = rbv_data::caption::gallery_captions(&state.pool, &gid, &q.model).await?;
Ok(Json(captions.into_iter().collect()))
}

View File

@@ -15,6 +15,8 @@ pub struct SearchRequest {
pub tag: Option<String>,
pub subject: Option<String>,
pub person_name: Option<String>,
pub caption: Option<String>,
pub label: Option<String>,
#[serde(default = "default_limit")]
pub limit: i64,
}
@@ -41,6 +43,8 @@ async fn search(
tag: req.tag.as_deref(),
subject: req.subject.as_deref(),
person_name: req.person_name.as_deref(),
caption: req.caption.as_deref(),
label: req.label.as_deref(),
limit: req.limit,
};

View File

@@ -0,0 +1,25 @@
[package]
name = "rbv-caption"
version.workspace = true
edition.workspace = true
license.workspace = true
[[bin]]
name = "rbv-caption"
path = "src/main.rs"
[dependencies]
rbv-entity = { workspace = true }
rbv-data = { workspace = true }
rbv-hash = { workspace = true }
clap = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
ort = { workspace = true }
tokenizers = { workspace = true }
image = { workspace = true }
reqwest = { workspace = true }
hf-hub = { workspace = true }

View File

@@ -0,0 +1,114 @@
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use tracing::info;
/// HuggingFace repo and file mapping for a captioning model.
struct ModelSource {
repo_id: &'static str,
files: &'static [(&'static str, &'static str)], // (remote_path, local_filename)
}
fn model_source(model_name: &str) -> Option<ModelSource> {
match model_name {
"florence-2-base" => Some(ModelSource {
repo_id: "onnx-community/Florence-2-base-ft",
files: &[
("onnx/vision_encoder.onnx", "vision_encoder.onnx"),
("onnx/encoder_model.onnx", "encoder.onnx"),
("onnx/embed_tokens.onnx", "embed_tokens.onnx"),
("onnx/decoder_model.onnx", "decoder.onnx"),
("tokenizer.json", "tokenizer.json"),
],
}),
"vit-gpt2" => Some(ModelSource {
repo_id: "Xenova/vit-gpt2-image-captioning",
files: &[
("onnx/encoder_model.onnx", "encoder.onnx"),
("onnx/decoder_model.onnx", "decoder.onnx"),
("tokenizer.json", "tokenizer.json"),
],
}),
// OCR variant reuses the same florence-2-base files.
"florence-2-base-ocr" => Some(ModelSource {
repo_id: "onnx-community/Florence-2-base-ft",
files: &[
("onnx/vision_encoder.onnx", "vision_encoder.onnx"),
("onnx/encoder_model.onnx", "encoder.onnx"),
("onnx/embed_tokens.onnx", "embed_tokens.onnx"),
("onnx/decoder_model.onnx", "decoder.onnx"),
("tokenizer.json", "tokenizer.json"),
],
}),
_ => None,
}
}
/// Ensure all required model files are present in the local model directory.
/// Downloads from HuggingFace if any are missing.
/// Returns the model directory path.
pub async fn ensure_model(model_name: &str, model_dir: &Path) -> Result<PathBuf> {
// OCR variant shares the same model files as the base caption model.
let dir_name = match model_name {
"florence-2-base-ocr" => "florence-2-base",
other => other,
};
let local_dir = model_dir.join("captioning").join(dir_name);
let source = match model_source(model_name) {
Some(s) => s,
None => {
// No auto-download — check if files were placed manually.
if local_dir.exists() {
return Ok(local_dir);
}
anyhow::bail!(
"No auto-download source for model '{model_name}'. \
Manually place model files in {}. \
Models with auto-download: florence-2-base, vit-gpt2",
local_dir.display()
);
}
};
// Check if all required files already exist.
let all_present = source.files.iter().all(|(_, local_name)| local_dir.join(local_name).exists());
if all_present {
return Ok(local_dir);
}
info!(
"Downloading model '{}' from HuggingFace repo '{}'...",
model_name, source.repo_id
);
std::fs::create_dir_all(&local_dir)
.with_context(|| format!("Failed to create directory: {}", local_dir.display()))?;
let api = hf_hub::api::tokio::Api::new()
.context("Failed to initialize HuggingFace API")?;
let repo = api.model(source.repo_id.to_string());
for (remote_path, local_name) in source.files {
let local_path = local_dir.join(local_name);
if local_path.exists() {
info!(" {} already exists, skipping", local_name);
continue;
}
info!(" Downloading {} ...", remote_path);
let cached_path = repo.get(remote_path).await
.with_context(|| format!("Failed to download {remote_path} from {}", source.repo_id))?;
// hf-hub caches files in its own directory. Copy to our model dir.
std::fs::copy(&cached_path, &local_path)
.with_context(|| format!(
"Failed to copy {} to {}",
cached_path.display(),
local_path.display()
))?;
info!(" Saved {}", local_path.display());
}
info!("Model '{}' ready.", model_name);
Ok(local_dir)
}

View File

@@ -0,0 +1,191 @@
mod download;
mod models;
use std::path::PathBuf;
use anyhow::Result;
use clap::Parser;
use tracing::info;
use rbv_entity::{Caption, ImageId, Label};
#[derive(Parser)]
#[command(name = "rbv-caption", about = "Generate image captions using ONNX models")]
struct Args {
/// PostgreSQL connection string
#[arg(long)]
database: String,
/// Model to use (blip-base, git-base, florence-2-base, florence-2-base-ocr, vit-gpt2)
#[arg(long)]
model: String,
/// Path to model directory (expects captioning/{model}/ subdirectory)
#[arg(long)]
model_dir: PathBuf,
/// CDN mapping in the form fs_prefix=url_prefix
#[arg(long)]
cdn_map: String,
/// Number of images per batch
#[arg(long, default_value = "100")]
batch_size: i64,
/// Number of concurrent download+caption tasks
#[arg(long, default_value = "1")]
concurrency: usize,
}
struct CdnMap {
fs_prefix: String,
url_prefix: String,
}
impl CdnMap {
fn resolve(&self, gallery_path: &str, filename: &str) -> Option<String> {
if gallery_path.starts_with(&self.fs_prefix) {
let rel = &gallery_path[self.fs_prefix.len()..];
Some(format!("{}{rel}/{filename}", self.url_prefix))
} else {
None
}
}
}
fn is_ocr_model(model_name: &str) -> bool {
model_name.ends_with("-ocr")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let args = Args::parse();
let (fs_prefix, url_prefix) = args.cdn_map.split_once('=')
.ok_or_else(|| anyhow::anyhow!("--cdn-map must be in the form fs_prefix=url_prefix"))?;
let cdn = CdnMap {
fs_prefix: fs_prefix.trim_end_matches('/').to_string(),
url_prefix: url_prefix.trim_end_matches('/').to_string(),
};
// Download model from HuggingFace if not present locally.
download::ensure_model(&args.model, &args.model_dir).await?;
info!("Loading model: {}", args.model);
let model: std::sync::Arc<dyn models::CaptionModel> = std::sync::Arc::from(
models::load_model(&args.model, &args.model_dir)?
);
info!("Model loaded.");
let pool = rbv_data::connect(&args.database, 2).await?;
let ocr = is_ocr_model(&args.model);
let total = if ocr {
rbv_data::label::count_unlabelled(&pool, &args.model).await?
} else {
rbv_data::caption::count_uncaptioned(&pool, &args.model).await?
};
let task = if ocr { "unlabelled" } else { "uncaptioned" };
info!("{total} {task} images for model '{}'", args.model);
if total == 0 {
info!("Nothing to do.");
return Ok(());
}
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
let mut processed = 0u64;
let mut errors = 0u64;
loop {
let batch = if ocr {
rbv_data::label::unlabelled_images(&pool, &args.model, args.batch_size).await?
} else {
rbv_data::caption::uncaptioned_images(&pool, &args.model, args.batch_size).await?
};
if batch.is_empty() {
break;
}
let batch_len = batch.len();
for (image_id_hex, gallery_path, filename) in &batch {
let result = process_image(
&http, &model, &pool, &cdn, &args.model, ocr,
image_id_hex, gallery_path, filename,
).await;
match result {
Ok(()) => processed += 1,
Err(e) => {
tracing::warn!("Error for {image_id_hex}: {e}");
errors += 1;
}
}
}
info!(
"Progress: {processed}/{total} {task}, {errors} errors ({batch_len} in last batch)"
);
}
info!("Done. {processed} {task}, {errors} errors.");
Ok(())
}
async fn process_image(
http: &reqwest::Client,
model: &std::sync::Arc<dyn models::CaptionModel>,
pool: &sqlx::PgPool,
cdn: &CdnMap,
model_name: &str,
ocr: bool,
image_id_hex: &str,
gallery_path: &str,
filename: &str,
) -> Result<()> {
let url = cdn.resolve(gallery_path, filename)
.ok_or_else(|| anyhow::anyhow!(
"CDN mapping does not match gallery path: {gallery_path}"
))?;
let response = http.get(&url).send().await?;
if !response.status().is_success() {
anyhow::bail!("HTTP {} fetching {url}", response.status());
}
let image_bytes = response.bytes().await?;
// Run inference (blocking — ONNX inference is CPU-bound)
let bytes = image_bytes.to_vec();
let model_clone = model.clone();
let text = tokio::task::spawn_blocking(move || {
model_clone.caption(&bytes)
}).await??;
let id_bytes = rbv_hash::from_hex(image_id_hex)?;
let image_id = ImageId(id_bytes);
if ocr {
rbv_data::label::upsert_label(pool, &Label {
image_id,
model: model_name.to_string(),
label: text,
}).await?;
} else {
rbv_data::caption::upsert_caption(pool, &Caption {
image_id,
model: model_name.to_string(),
caption: text,
}).await?;
}
Ok(())
}

View File

@@ -0,0 +1,145 @@
use anyhow::{Context, Result};
use image::GenericImageView;
use ort::session::Session;
use ort::value::Tensor;
use std::path::Path;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use super::CaptionModel;
const IMAGE_SIZE: u32 = 384;
const MAX_TOKENS: usize = 50;
const MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
const STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];
pub struct BlipBase {
encoder: Mutex<Session>,
decoder: Mutex<Session>,
tokenizer: Tokenizer,
}
impl BlipBase {
pub fn load(model_dir: &Path) -> Result<Self> {
let encoder_path = model_dir.join("encoder.onnx");
let decoder_path = model_dir.join("decoder.onnx");
let tokenizer_path = model_dir.join("tokenizer.json");
for p in [&encoder_path, &decoder_path, &tokenizer_path] {
anyhow::ensure!(p.exists(), "Missing model file: {}", p.display());
}
let encoder = Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))?
.commit_from_file(&encoder_path)
.map_err(|e| anyhow::anyhow!("Failed to load BLIP encoder: {e}"))?;
let decoder = Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))?
.commit_from_file(&decoder_path)
.map_err(|e| anyhow::anyhow!("Failed to load BLIP decoder: {e}"))?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;
Ok(Self {
encoder: Mutex::new(encoder),
decoder: Mutex::new(decoder),
tokenizer,
})
}
fn preprocess_image(&self, image_bytes: &[u8]) -> Result<Vec<f32>> {
let img = image::load_from_memory(image_bytes)
.context("Failed to decode image")?;
let img = img.resize_exact(IMAGE_SIZE, IMAGE_SIZE, image::imageops::FilterType::Lanczos3);
let mut pixels = vec![0f32; 3 * IMAGE_SIZE as usize * IMAGE_SIZE as usize];
for y in 0..IMAGE_SIZE {
for x in 0..IMAGE_SIZE {
let pixel = img.get_pixel(x, y);
for c in 0..3usize {
let idx = c * (IMAGE_SIZE as usize * IMAGE_SIZE as usize)
+ y as usize * IMAGE_SIZE as usize
+ x as usize;
pixels[idx] = (pixel[c] as f32 / 255.0 - MEAN[c]) / STD[c];
}
}
}
Ok(pixels)
}
}
impl CaptionModel for BlipBase {
fn caption(&self, image_bytes: &[u8]) -> Result<String> {
let pixels = self.preprocess_image(image_bytes)?;
let pixel_input = Tensor::<f32>::from_array(
([1usize, 3, IMAGE_SIZE as usize, IMAGE_SIZE as usize], pixels),
).map_err(|e| anyhow::anyhow!("{e}"))?;
// Encode image
let (enc_data, enc_shape) = {
let mut enc_guard = self.encoder.lock().unwrap();
let encoder_outputs = enc_guard.run(ort::inputs![pixel_input])
.map_err(|e| anyhow::anyhow!("{e}"))?;
let (shape, data) = encoder_outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("{e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
// Autoregressive decoding
let bos_id = self.tokenizer.token_to_id("[CLS]").unwrap_or(101) as i64;
let sep_id = self.tokenizer.token_to_id("[SEP]").unwrap_or(102) as i64;
let mut token_ids: Vec<i64> = vec![bos_id];
for _ in 0..MAX_TOKENS {
let seq_len = token_ids.len();
let input_ids = Tensor::<i64>::from_array(
([1usize, seq_len], token_ids.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let attention_mask = Tensor::<i64>::from_array(
([1usize, seq_len], vec![1i64; seq_len]),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let enc_hidden = Tensor::<f32>::from_array(
(enc_shape.clone(), enc_data.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (logits_data, vocab_size) = {
let mut dec_guard = self.decoder.lock().unwrap();
let decoder_outputs = dec_guard.run(ort::inputs![input_ids, attention_mask, enc_hidden])
.map_err(|e| anyhow::anyhow!("{e}"))?;
let (shape, data) = decoder_outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("{e}"))?;
let vs = *shape.last().unwrap_or(&0) as usize;
(data.to_vec(), vs)
};
let last_pos = seq_len - 1;
let offset = last_pos * vocab_size;
let mut best_id = 0usize;
let mut best_score = f32::NEG_INFINITY;
for v in 0..vocab_size {
let score = logits_data[offset + v];
if score > best_score {
best_score = score;
best_id = v;
}
}
let next_id = best_id as i64;
if next_id == sep_id || next_id == 0 {
break;
}
token_ids.push(next_id);
}
let output_ids: Vec<u32> = token_ids[1..].iter().map(|&id| id as u32).collect();
let caption = self.tokenizer.decode(&output_ids, true)
.map_err(|e| anyhow::anyhow!("Tokenizer decode failed: {e}"))?;
Ok(caption.trim().to_string())
}
}

View File

@@ -0,0 +1,232 @@
use anyhow::{Context, Result};
use image::GenericImageView;
use ort::session::Session;
use ort::value::Tensor;
use std::path::Path;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use super::CaptionModel;
const IMAGE_SIZE: u32 = 768;
const MAX_TOKENS: usize = 120;
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];
pub struct Florence2Base {
vision_encoder: Mutex<Session>,
encoder: Mutex<Session>,
embed_tokens: Mutex<Session>,
decoder: Mutex<Session>,
tokenizer: Tokenizer,
prompt: String,
}
impl Florence2Base {
pub fn load(model_dir: &Path) -> Result<Self> {
Self::load_with_prompt(model_dir, "Describe with a paragraph what is shown in the image.")
}
pub fn load_ocr(model_dir: &Path) -> Result<Self> {
Self::load_with_prompt(model_dir, "What is the text in the image?")
}
pub fn load_with_prompt(model_dir: &Path, prompt: &str) -> Result<Self> {
let vision_enc_path = model_dir.join("vision_encoder.onnx");
let encoder_path = model_dir.join("encoder.onnx");
let embed_path = model_dir.join("embed_tokens.onnx");
let decoder_path = model_dir.join("decoder.onnx");
let tokenizer_path = model_dir.join("tokenizer.json");
for p in [&vision_enc_path, &encoder_path, &embed_path, &decoder_path, &tokenizer_path] {
anyhow::ensure!(p.exists(), "Missing model file: {}", p.display());
}
let load = |path: &Path, name: &str| -> Result<Session> {
Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))
.and_then(|mut b| b.commit_from_file(path)
.map_err(|e| anyhow::anyhow!("Failed to load Florence-2 {name}: {e}")))
};
let vision_encoder = load(&vision_enc_path, "vision encoder")?;
let encoder = load(&encoder_path, "text encoder")?;
let embed_tokens = load(&embed_path, "embed_tokens")?;
let decoder = load(&decoder_path, "decoder")?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;
Ok(Self {
vision_encoder: Mutex::new(vision_encoder),
encoder: Mutex::new(encoder),
embed_tokens: Mutex::new(embed_tokens),
decoder: Mutex::new(decoder),
tokenizer,
prompt: prompt.to_string(),
})
}
fn preprocess_image(&self, image_bytes: &[u8]) -> Result<Vec<f32>> {
let img = image::load_from_memory(image_bytes)
.context("Failed to decode image")?;
let img = img.resize_exact(IMAGE_SIZE, IMAGE_SIZE, image::imageops::FilterType::Lanczos3);
let mut pixels = vec![0f32; 3 * IMAGE_SIZE as usize * IMAGE_SIZE as usize];
for y in 0..IMAGE_SIZE {
for x in 0..IMAGE_SIZE {
let pixel = img.get_pixel(x, y);
for c in 0..3usize {
let idx = c * (IMAGE_SIZE as usize * IMAGE_SIZE as usize)
+ y as usize * IMAGE_SIZE as usize
+ x as usize;
pixels[idx] = (pixel[c] as f32 / 255.0 - MEAN[c]) / STD[c];
}
}
}
Ok(pixels)
}
}
impl CaptionModel for Florence2Base {
fn caption(&self, image_bytes: &[u8]) -> Result<String> {
let pixels = self.preprocess_image(image_bytes)?;
let pixel_input = Tensor::<f32>::from_array(
([1usize, 3, IMAGE_SIZE as usize, IMAGE_SIZE as usize], pixels),
).map_err(|e| anyhow::anyhow!("{e}"))?;
// Step 1: Run vision encoder on image pixels.
let (image_features, image_shape) = {
let mut guard = self.vision_encoder.lock().unwrap();
let outputs = guard.run(ort::inputs![pixel_input])
.map_err(|e| anyhow::anyhow!("vision_encoder: {e}"))?;
let (shape, data) = outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("vision_encoder output: {e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
// Step 2: Tokenize the task prompt and get text embeddings.
// Florence-2 maps task tokens to natural language prompts before tokenization.
let encoding = self.tokenizer.encode(self.prompt.as_str(), false)
.map_err(|e| anyhow::anyhow!("Tokenizer encode failed: {e}"))?;
let prompt_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let prompt_len = prompt_ids.len();
let (text_embeds, text_shape) = {
let input_ids = Tensor::<i64>::from_array(
([1usize, prompt_len], prompt_ids.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let mut guard = self.embed_tokens.lock().unwrap();
let outputs = guard.run(ort::inputs![input_ids])
.map_err(|e| anyhow::anyhow!("embed_tokens: {e}"))?;
let (shape, data) = outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("embed_tokens output: {e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
// Step 3: Concatenate image features and text embeddings, run encoder.
// inputs_embeds shape: [1, image_seq_len + text_seq_len, hidden_dim]
let hidden_dim = image_shape.last().copied().unwrap_or(0);
let image_seq_len = image_shape.get(1).copied().unwrap_or(0);
let text_seq_len = text_shape.get(1).copied().unwrap_or(0);
let total_seq_len = image_seq_len + text_seq_len;
let mut combined = Vec::with_capacity(total_seq_len * hidden_dim);
combined.extend_from_slice(&image_features);
combined.extend_from_slice(&text_embeds);
let attention_mask = Tensor::<i64>::from_array(
([1usize, total_seq_len], vec![1i64; total_seq_len]),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let inputs_embeds = Tensor::<f32>::from_array(
([1usize, total_seq_len, hidden_dim], combined),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (enc_hidden, enc_shape) = {
let mut guard = self.encoder.lock().unwrap();
let outputs = guard.run(ort::inputs![attention_mask, inputs_embeds])
.map_err(|e| anyhow::anyhow!("encoder: {e}"))?;
let (shape, data) = outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("encoder output: {e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
// Step 4: Autoregressive decoding.
let eos_id = self.tokenizer.token_to_id("</s>").unwrap_or(2) as i64;
// decoder_start_token_id = 2 (</s>) per generation_config.json
let decoder_start_id = eos_id;
let mut token_ids: Vec<i64> = vec![decoder_start_id];
for _ in 0..MAX_TOKENS {
let seq_len = token_ids.len();
// Decoder expects: encoder_attention_mask, encoder_hidden_states, inputs_embeds
// inputs_embeds = embed_tokens(decoder token IDs)
let dec_input_ids = Tensor::<i64>::from_array(
([1usize, seq_len], token_ids.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (dec_embeds, dec_embed_shape) = {
let mut guard = self.embed_tokens.lock().unwrap();
let outputs = guard.run(ort::inputs![dec_input_ids])
.map_err(|e| anyhow::anyhow!("embed_tokens (decoder): {e}"))?;
let (shape, data) = outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("embed_tokens output: {e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
let enc_attn_mask = Tensor::<i64>::from_array(
([1usize, total_seq_len], vec![1i64; total_seq_len]),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let enc_hidden_tensor = Tensor::<f32>::from_array(
(enc_shape.clone(), enc_hidden.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let inputs_embeds = Tensor::<f32>::from_array(
(dec_embed_shape, dec_embeds),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (logits_data, vocab_size) = {
let mut guard = self.decoder.lock().unwrap();
let outputs = guard.run(
ort::inputs![enc_attn_mask, enc_hidden_tensor, inputs_embeds]
).map_err(|e| anyhow::anyhow!("decoder: {e}"))?;
let (shape, data) = outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("decoder output: {e}"))?;
let vs = *shape.last().unwrap_or(&0) as usize;
(data.to_vec(), vs)
};
let last_pos = seq_len - 1;
let offset = last_pos * vocab_size;
let mut best_id = 0usize;
let mut best_score = f32::NEG_INFINITY;
for v in 0..vocab_size {
let score = logits_data[offset + v];
if score > best_score {
best_score = score;
best_id = v;
}
}
let next_id = best_id as i64;
if next_id == eos_id {
break;
}
token_ids.push(next_id);
}
// Skip BOS token for decoding.
let output_ids: Vec<u32> = token_ids[1..].iter().map(|&id| id as u32).collect();
let caption = self.tokenizer.decode(&output_ids, true)
.map_err(|e| anyhow::anyhow!("Tokenizer decode failed: {e}"))?;
let caption = caption.trim();
Ok(caption.to_string())
}
}

View File

@@ -0,0 +1,139 @@
use anyhow::{Context, Result};
use image::GenericImageView;
use ort::session::Session;
use ort::value::Tensor;
use std::path::Path;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use super::CaptionModel;
const IMAGE_SIZE: u32 = 224;
const MAX_TOKENS: usize = 50;
const MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
const STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];
pub struct GitBase {
encoder: Mutex<Session>,
decoder: Mutex<Session>,
tokenizer: Tokenizer,
}
impl GitBase {
pub fn load(model_dir: &Path) -> Result<Self> {
let encoder_path = model_dir.join("encoder.onnx");
let decoder_path = model_dir.join("decoder.onnx");
let tokenizer_path = model_dir.join("tokenizer.json");
for p in [&encoder_path, &decoder_path, &tokenizer_path] {
anyhow::ensure!(p.exists(), "Missing model file: {}", p.display());
}
let encoder = Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))?
.commit_from_file(&encoder_path)
.map_err(|e| anyhow::anyhow!("Failed to load GIT encoder: {e}"))?;
let decoder = Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))?
.commit_from_file(&decoder_path)
.map_err(|e| anyhow::anyhow!("Failed to load GIT decoder: {e}"))?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;
Ok(Self {
encoder: Mutex::new(encoder),
decoder: Mutex::new(decoder),
tokenizer,
})
}
fn preprocess_image(&self, image_bytes: &[u8]) -> Result<Vec<f32>> {
let img = image::load_from_memory(image_bytes)
.context("Failed to decode image")?;
let img = img.resize_exact(IMAGE_SIZE, IMAGE_SIZE, image::imageops::FilterType::Lanczos3);
let mut pixels = vec![0f32; 3 * IMAGE_SIZE as usize * IMAGE_SIZE as usize];
for y in 0..IMAGE_SIZE {
for x in 0..IMAGE_SIZE {
let pixel = img.get_pixel(x, y);
for c in 0..3usize {
let idx = c * (IMAGE_SIZE as usize * IMAGE_SIZE as usize)
+ y as usize * IMAGE_SIZE as usize
+ x as usize;
pixels[idx] = (pixel[c] as f32 / 255.0 - MEAN[c]) / STD[c];
}
}
}
Ok(pixels)
}
}
impl CaptionModel for GitBase {
fn caption(&self, image_bytes: &[u8]) -> Result<String> {
let pixels = self.preprocess_image(image_bytes)?;
let pixel_input = Tensor::<f32>::from_array(
([1usize, 3, IMAGE_SIZE as usize, IMAGE_SIZE as usize], pixels),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (enc_data, enc_shape) = {
let mut enc_guard = self.encoder.lock().unwrap();
let encoder_outputs = enc_guard.run(ort::inputs![pixel_input])
.map_err(|e| anyhow::anyhow!("{e}"))?;
let (shape, data) = encoder_outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("{e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
let bos_id = self.tokenizer.token_to_id("[CLS]").unwrap_or(101) as i64;
let eos_id = self.tokenizer.token_to_id("[SEP]").unwrap_or(102) as i64;
let mut token_ids: Vec<i64> = vec![bos_id];
for _ in 0..MAX_TOKENS {
let seq_len = token_ids.len();
let input_ids = Tensor::<i64>::from_array(
([1usize, seq_len], token_ids.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let enc_hidden = Tensor::<f32>::from_array(
(enc_shape.clone(), enc_data.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (logits_data, vocab_size) = {
let mut dec_guard = self.decoder.lock().unwrap();
let decoder_outputs = dec_guard.run(ort::inputs![input_ids, enc_hidden])
.map_err(|e| anyhow::anyhow!("{e}"))?;
let (shape, data) = decoder_outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("{e}"))?;
let vs = *shape.last().unwrap_or(&0) as usize;
(data.to_vec(), vs)
};
let last_pos = seq_len - 1;
let offset = last_pos * vocab_size;
let mut best_id = 0usize;
let mut best_score = f32::NEG_INFINITY;
for v in 0..vocab_size {
let score = logits_data[offset + v];
if score > best_score {
best_score = score;
best_id = v;
}
}
let next_id = best_id as i64;
if next_id == eos_id || next_id == 0 {
break;
}
token_ids.push(next_id);
}
let output_ids: Vec<u32> = token_ids[1..].iter().map(|&id| id as u32).collect();
let caption = self.tokenizer.decode(&output_ids, true)
.map_err(|e| anyhow::anyhow!("Tokenizer decode failed: {e}"))?;
Ok(caption.trim().to_string())
}
}

View File

@@ -0,0 +1,40 @@
pub mod blip;
pub mod git;
pub mod florence;
pub mod vit_gpt2;
use anyhow::Result;
use std::path::Path;
pub trait CaptionModel: Send + Sync {
fn caption(&self, image_bytes: &[u8]) -> Result<String>;
}
pub fn load_model(model_name: &str, model_dir: &Path) -> Result<Box<dyn CaptionModel>> {
match model_name {
"blip-base" | "git-base" | "florence-2-base" | "vit-gpt2" => {
let base = model_dir.join("captioning").join(model_name);
if !base.exists() {
anyhow::bail!("Model directory not found: {}", base.display());
}
match model_name {
"blip-base" => Ok(Box::new(blip::BlipBase::load(&base)?)),
"git-base" => Ok(Box::new(git::GitBase::load(&base)?)),
"florence-2-base" => Ok(Box::new(florence::Florence2Base::load(&base)?)),
"vit-gpt2" => Ok(Box::new(vit_gpt2::VitGpt2::load(&base)?)),
_ => unreachable!(),
}
}
"florence-2-base-ocr" => {
// OCR variant reuses the same ONNX files as florence-2-base.
let base = model_dir.join("captioning").join("florence-2-base");
if !base.exists() {
anyhow::bail!("Model directory not found: {}", base.display());
}
Ok(Box::new(florence::Florence2Base::load_ocr(&base)?))
}
_ => anyhow::bail!(
"Unknown model: {model_name}. Supported: blip-base, git-base, florence-2-base, florence-2-base-ocr, vit-gpt2"
),
}
}

View File

@@ -0,0 +1,145 @@
use anyhow::{Context, Result};
use image::GenericImageView;
use ort::session::Session;
use ort::value::Tensor;
use std::path::Path;
use std::sync::Mutex;
use tokenizers::Tokenizer;
use super::CaptionModel;
const IMAGE_SIZE: u32 = 224;
const MAX_TOKENS: usize = 50;
// ViT uses ImageNet normalization
const MEAN: [f32; 3] = [0.5, 0.5, 0.5];
const STD: [f32; 3] = [0.5, 0.5, 0.5];
pub struct VitGpt2 {
encoder: Mutex<Session>,
decoder: Mutex<Session>,
tokenizer: Tokenizer,
}
impl VitGpt2 {
pub fn load(model_dir: &Path) -> Result<Self> {
let encoder_path = model_dir.join("encoder.onnx");
let decoder_path = model_dir.join("decoder.onnx");
let tokenizer_path = model_dir.join("tokenizer.json");
for p in [&encoder_path, &decoder_path, &tokenizer_path] {
anyhow::ensure!(p.exists(), "Missing model file: {}", p.display());
}
let encoder = Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))?
.commit_from_file(&encoder_path)
.map_err(|e| anyhow::anyhow!("Failed to load ViT-GPT2 encoder: {e}"))?;
let decoder = Session::builder()
.map_err(|e| anyhow::anyhow!("{e}"))?
.commit_from_file(&decoder_path)
.map_err(|e| anyhow::anyhow!("Failed to load ViT-GPT2 decoder: {e}"))?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;
Ok(Self {
encoder: Mutex::new(encoder),
decoder: Mutex::new(decoder),
tokenizer,
})
}
fn preprocess_image(&self, image_bytes: &[u8]) -> Result<Vec<f32>> {
let img = image::load_from_memory(image_bytes)
.context("Failed to decode image")?;
let img = img.resize_exact(IMAGE_SIZE, IMAGE_SIZE, image::imageops::FilterType::Lanczos3);
let mut pixels = vec![0f32; 3 * IMAGE_SIZE as usize * IMAGE_SIZE as usize];
for y in 0..IMAGE_SIZE {
for x in 0..IMAGE_SIZE {
let pixel = img.get_pixel(x, y);
for c in 0..3usize {
let idx = c * (IMAGE_SIZE as usize * IMAGE_SIZE as usize)
+ y as usize * IMAGE_SIZE as usize
+ x as usize;
pixels[idx] = (pixel[c] as f32 / 255.0 - MEAN[c]) / STD[c];
}
}
}
Ok(pixels)
}
}
impl CaptionModel for VitGpt2 {
fn caption(&self, image_bytes: &[u8]) -> Result<String> {
let pixels = self.preprocess_image(image_bytes)?;
let pixel_input = Tensor::<f32>::from_array(
([1usize, 3, IMAGE_SIZE as usize, IMAGE_SIZE as usize], pixels),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (enc_data, enc_shape) = {
let mut enc_guard = self.encoder.lock().unwrap();
let encoder_outputs = enc_guard.run(ort::inputs![pixel_input])
.map_err(|e| anyhow::anyhow!("{e}"))?;
let (shape, data) = encoder_outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("{e}"))?;
(data.to_vec(), shape.iter().map(|&d| d as usize).collect::<Vec<_>>())
};
// GPT-2 uses BOS token to start generation
let bos_id = self.tokenizer.token_to_id("<|endoftext|>").unwrap_or(50256) as i64;
let eos_id = bos_id; // GPT-2 uses same token for BOS and EOS
let mut token_ids: Vec<i64> = vec![bos_id];
for _ in 0..MAX_TOKENS {
let seq_len = token_ids.len();
let input_ids = Tensor::<i64>::from_array(
([1usize, seq_len], token_ids.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let enc_hidden = Tensor::<f32>::from_array(
(enc_shape.clone(), enc_data.clone()),
).map_err(|e| anyhow::anyhow!("{e}"))?;
let (logits_data, vocab_size) = {
let mut dec_guard = self.decoder.lock().unwrap();
let decoder_outputs = dec_guard.run(ort::inputs![input_ids, enc_hidden])
.map_err(|e| anyhow::anyhow!("{e}"))?;
let (shape, data) = decoder_outputs[0].try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("{e}"))?;
let vs = *shape.last().unwrap_or(&0) as usize;
(data.to_vec(), vs)
};
let last_pos = seq_len - 1;
let offset = last_pos * vocab_size;
let mut best_id = 0usize;
let mut best_score = f32::NEG_INFINITY;
for v in 0..vocab_size {
let score = logits_data[offset + v];
if score > best_score {
best_score = score;
best_id = v;
}
}
let next_id = best_id as i64;
// Stop on EOS or after generating enough tokens
if next_id == eos_id && token_ids.len() > 1 {
break;
}
if next_id == 0 {
break;
}
token_ids.push(next_id);
}
let output_ids: Vec<u32> = token_ids[1..].iter().map(|&id| id as u32).collect();
let caption = self.tokenizer.decode(&output_ids, true)
.map_err(|e| anyhow::anyhow!("Tokenizer decode failed: {e}"))?;
Ok(caption.trim().to_string())
}
}

View File

@@ -0,0 +1,148 @@
use anyhow::Result;
use sqlx::{PgPool, Row};
use rbv_entity::{Caption, Gallery, GalleryId, ImageId};
pub async fn upsert_caption(pool: &PgPool, caption: &Caption) -> Result<()> {
sqlx::query(
r#"
INSERT INTO captions (image_id, model, caption)
VALUES ($1, $2, $3)
ON CONFLICT (image_id, model) DO UPDATE SET
caption = EXCLUDED.caption,
created_at = now()
"#,
)
.bind(caption.image_id.as_bytes())
.bind(&caption.model)
.bind(&caption.caption)
.execute(pool)
.await?;
Ok(())
}
pub async fn get_captions(pool: &PgPool, image_id: &ImageId) -> Result<Vec<Caption>> {
let rows = sqlx::query(
"SELECT image_id, model, caption FROM captions WHERE image_id = $1 ORDER BY model",
)
.bind(image_id.as_bytes())
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
let id_bytes: Vec<u8> = r.get("image_id");
Caption {
image_id: ImageId(id_bytes.try_into().expect("32-byte id")),
model: r.get("model"),
caption: r.get("caption"),
}
}).collect())
}
/// Fetch a batch of images that have no caption for the given model,
/// returning (image_id_hex, gallery_path, filename) tuples.
pub async fn uncaptioned_images(
pool: &PgPool,
model: &str,
batch_size: i64,
) -> Result<Vec<(String, String, String)>> {
let rows = sqlx::query(
r#"
SELECT encode(gi.image_id, 'hex') AS image_id_hex, g.path, gi.filename
FROM gallery_images gi
JOIN galleries g ON g.id = gi.gallery_id
WHERE NOT EXISTS (
SELECT 1 FROM captions c
WHERE c.image_id = gi.image_id AND c.model = $1
)
LIMIT $2
"#,
)
.bind(model)
.bind(batch_size)
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
(
r.get::<String, _>("image_id_hex"),
r.get::<String, _>("path"),
r.get::<String, _>("filename"),
)
}).collect())
}
pub async fn count_uncaptioned(pool: &PgPool, model: &str) -> Result<i64> {
let row = sqlx::query(
r#"
SELECT COUNT(*) AS count
FROM gallery_images gi
WHERE NOT EXISTS (
SELECT 1 FROM captions c
WHERE c.image_id = gi.image_id AND c.model = $1
)
"#,
)
.bind(model)
.fetch_one(pool)
.await?;
Ok(row.get("count"))
}
/// Get captions for all images in a gallery, for a specific model.
/// Returns (image_id_hex, caption) pairs.
pub async fn gallery_captions(
pool: &PgPool,
gallery_id: &GalleryId,
model: &str,
) -> Result<Vec<(String, String)>> {
let rows = sqlx::query(
r#"
SELECT encode(c.image_id, 'hex') AS image_id_hex, c.caption
FROM captions c
JOIN gallery_images gi ON gi.image_id = c.image_id
WHERE gi.gallery_id = $1 AND c.model = $2
"#,
)
.bind(gallery_id.as_bytes())
.bind(model)
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
(r.get::<String, _>("image_id_hex"), r.get::<String, _>("caption"))
}).collect())
}
/// Search for galleries containing images with captions matching the query.
pub async fn search_captions(pool: &PgPool, query: &str, limit: i64) -> Result<Vec<Gallery>> {
let rows = sqlx::query(
r#"
SELECT DISTINCT g.id, g.source_id, g.collection, g.source_name,
g.source_url, g.subjects, g.tags, g.path
FROM galleries g
JOIN gallery_images gi ON gi.gallery_id = g.id
JOIN captions c ON c.image_id = gi.image_id
WHERE c.caption ILIKE '%' || $1 || '%'
ORDER BY g.collection, g.source_name
LIMIT $2
"#,
)
.bind(query)
.bind(limit)
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
let id_bytes: Vec<u8> = r.get("id");
Gallery {
id: GalleryId(id_bytes.try_into().expect("32-byte id")),
source_id: r.get::<i64, _>("source_id") as u64,
collection: r.get("collection"),
source_name: r.get("source_name"),
source_url: r.get("source_url"),
subjects: r.get("subjects"),
tags: r.get("tags"),
path: r.get::<String, _>("path").into(),
}
}).collect())
}

View File

@@ -0,0 +1,127 @@
use anyhow::Result;
use sqlx::{PgPool, Row};
use rbv_entity::{Gallery, GalleryId, Label};
pub async fn upsert_label(pool: &PgPool, label: &Label) -> Result<()> {
sqlx::query(
r#"
INSERT INTO labels (image_id, model, label)
VALUES ($1, $2, $3)
ON CONFLICT (image_id, model) DO UPDATE SET
label = EXCLUDED.label,
created_at = now()
"#,
)
.bind(label.image_id.as_bytes())
.bind(&label.model)
.bind(&label.label)
.execute(pool)
.await?;
Ok(())
}
pub async fn unlabelled_images(
pool: &PgPool,
model: &str,
batch_size: i64,
) -> Result<Vec<(String, String, String)>> {
let rows = sqlx::query(
r#"
SELECT encode(gi.image_id, 'hex') AS image_id_hex, g.path, gi.filename
FROM gallery_images gi
JOIN galleries g ON g.id = gi.gallery_id
WHERE NOT EXISTS (
SELECT 1 FROM labels l
WHERE l.image_id = gi.image_id AND l.model = $1
)
LIMIT $2
"#,
)
.bind(model)
.bind(batch_size)
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
(
r.get::<String, _>("image_id_hex"),
r.get::<String, _>("path"),
r.get::<String, _>("filename"),
)
}).collect())
}
pub async fn count_unlabelled(pool: &PgPool, model: &str) -> Result<i64> {
let row = sqlx::query(
r#"
SELECT COUNT(*) AS count
FROM gallery_images gi
WHERE NOT EXISTS (
SELECT 1 FROM labels l
WHERE l.image_id = gi.image_id AND l.model = $1
)
"#,
)
.bind(model)
.fetch_one(pool)
.await?;
Ok(row.get("count"))
}
/// Search for galleries containing images with labels matching the query.
/// Excludes empty labels (scanned but no text found).
pub async fn search_labels(pool: &PgPool, query: &str, limit: i64) -> Result<Vec<Gallery>> {
let rows = sqlx::query(
r#"
SELECT DISTINCT g.id, g.source_id, g.collection, g.source_name,
g.source_url, g.subjects, g.tags, g.path
FROM galleries g
JOIN gallery_images gi ON gi.gallery_id = g.id
JOIN labels l ON l.image_id = gi.image_id
WHERE l.label <> '' AND l.label ILIKE '%' || $1 || '%'
ORDER BY g.collection, g.source_name
LIMIT $2
"#,
)
.bind(query)
.bind(limit)
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
let id_bytes: Vec<u8> = r.get("id");
Gallery {
id: GalleryId(id_bytes.try_into().expect("32-byte id")),
source_id: r.get::<i64, _>("source_id") as u64,
collection: r.get("collection"),
source_name: r.get("source_name"),
source_url: r.get("source_url"),
subjects: r.get("subjects"),
tags: r.get("tags"),
path: r.get::<String, _>("path").into(),
}
}).collect())
}
pub async fn gallery_labels(
pool: &PgPool,
gallery_id: &GalleryId,
model: &str,
) -> Result<Vec<(String, String)>> {
let rows = sqlx::query(
r#"
SELECT encode(l.image_id, 'hex') AS image_id_hex, l.label
FROM labels l
JOIN gallery_images gi ON gi.image_id = l.image_id
WHERE gi.gallery_id = $1 AND l.model = $2 AND l.label <> ''
"#,
)
.bind(gallery_id.as_bytes())
.bind(model)
.fetch_all(pool)
.await?;
Ok(rows.iter().map(|r| {
(r.get::<String, _>("image_id_hex"), r.get::<String, _>("label"))
}).collect())
}

View File

@@ -6,6 +6,8 @@ pub mod person;
pub mod clip;
pub mod user;
pub mod favourite;
pub mod caption;
pub mod label;
pub use pool::{connect, run_migrations, rebuild_vector_indexes};
pub use sqlx::PgPool;

View File

@@ -0,0 +1,9 @@
use serde::{Deserialize, Serialize};
use crate::ImageId;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Caption {
pub image_id: ImageId,
pub model: String,
pub caption: String,
}

View File

@@ -0,0 +1,9 @@
use serde::{Deserialize, Serialize};
use crate::ImageId;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Label {
pub image_id: ImageId,
pub model: String,
pub label: String,
}

View File

@@ -4,6 +4,7 @@ pub mod face;
pub mod person;
pub mod clip;
pub mod user;
pub mod caption;
pub use gallery::{Gallery, GalleryId};
pub use image::{Image, GalleryImage, ImageId};
@@ -11,3 +12,7 @@ pub use face::{FaceDetection, FaceId, BoundingBox};
pub use person::{Person, PersonId, PersonName};
pub use clip::ClipEmbedding;
pub use user::{User, Session};
pub use caption::Caption;
pub mod label;
pub use label::Label;

View File

@@ -18,6 +18,8 @@ pub struct SearchParams<'a> {
pub tag: Option<&'a str>,
pub subject: Option<&'a str>,
pub person_name: Option<&'a str>,
pub caption: Option<&'a str>,
pub label: Option<&'a str>,
pub limit: i64,
}
@@ -47,11 +49,13 @@ async fn search_quick(
let fetch_limit = params.limit * 2;
// Run all searches concurrently.
let (tag_galleries, subject_galleries, matched_person_ids, clip_hits) = tokio::try_join!(
let (tag_galleries, subject_galleries, matched_person_ids, clip_hits, caption_galleries, label_galleries) = tokio::try_join!(
rbv_data::gallery::search_galleries_by_tag(pool, query, fetch_limit),
rbv_data::gallery::search_galleries_by_subject(pool, query, fetch_limit),
rbv_data::person::search_persons_by_name(pool, query),
search_by_text(pool, ml, query, fetch_limit),
rbv_data::caption::search_captions(pool, query, fetch_limit),
rbv_data::label::search_labels(pool, query, fetch_limit),
)?;
// Resolve person matches to galleries.
@@ -79,6 +83,16 @@ async fn search_quick(
let entry = scored.entry(id).or_insert_with(|| (g, 0.0));
entry.1 += 1.0;
}
for g in caption_galleries {
let id = g.id.clone();
let entry = scored.entry(id).or_insert_with(|| (g, 0.0));
entry.1 += 1.0;
}
for g in label_galleries {
let id = g.id.clone();
let entry = scored.entry(id).or_insert_with(|| (g, 0.0));
entry.1 += 1.0;
}
// Resolve CLIP image hits to galleries.
for (image_id, score) in clip_hits {
@@ -151,6 +165,30 @@ async fn search_advanced(
filter_sets.push(ids);
}
// Caption filter.
if let Some(caption_query) = params.caption.filter(|c| !c.is_empty()) {
let galleries = rbv_data::caption::search_captions(pool, caption_query, fetch_limit).await?;
let mut ids = HashSet::new();
for g in galleries {
let id = g.id.clone();
ids.insert(id.clone());
gallery_map.entry(id).or_insert_with(|| (g, 1.0));
}
filter_sets.push(ids);
}
// Label (OCR text) filter.
if let Some(label_query) = params.label.filter(|l| !l.is_empty()) {
let galleries = rbv_data::label::search_labels(pool, label_query, fetch_limit).await?;
let mut ids = HashSet::new();
for g in galleries {
let id = g.id.clone();
ids.insert(id.clone());
gallery_map.entry(id).or_insert_with(|| (g, 1.0));
}
filter_sets.push(ids);
}
// Person name filter.
if let Some(name) = params.person_name.filter(|n| !n.is_empty()) {
let person_ids = rbv_data::person::search_persons_by_name(pool, name).await?;

View File

@@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS captions (
image_id BYTEA NOT NULL REFERENCES images(id) ON DELETE CASCADE,
model TEXT NOT NULL,
caption TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (image_id, model)
);
CREATE INDEX IF NOT EXISTS idx_captions_image ON captions (image_id);

View File

@@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS labels (
image_id BYTEA NOT NULL REFERENCES images(id) ON DELETE CASCADE,
model TEXT NOT NULL,
label TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (image_id, model)
);
CREATE INDEX IF NOT EXISTS idx_labels_image ON labels (image_id);

View File

@@ -4,6 +4,7 @@ postgres_host=gramathea.kosherinata.internal
api_host=gramathea.kosherinata.internal
index_host=gramathea.kosherinata.internal
ui_host=gramathea.kosherinata.internal
caption_host=quadbrat.hanzalova.internal
deploy_db() {
if rsync \
@@ -109,6 +110,48 @@ deploy_api() {
fi
}
deploy_caption() {
cargo build --release -p rbv-caption
postgres_password=$(grep POSTGRES_PASSWORD asset/quadlet/.env | cut -d '=' -f 2)
for unit in rbv-caption@{florence-2-base,florence-2-base-ocr,vit-gpt2}.service; do
state=$(ssh ${caption_host} "systemctl is-active ${unit} 2>/dev/null")
if [[ ${state} == "active" || ${state} == "activating" ]]; then
if ssh ${caption_host} sudo systemctl stop ${unit}; then
echo "${unit} stopped successfully"
else
echo "failed to stop ${unit}"
exit 1
fi
fi
done
if rsync \
--archive \
--compress \
--rsync-path 'sudo rsync' \
--chown root:root \
--chmod '+x' \
target/release/rbv-caption \
${caption_host}:/usr/local/bin/rbv-caption \
&& rsync \
--archive \
--compress \
--rsync-path 'sudo rsync' \
--chown root:root \
asset/systemd/rbv-caption@.service \
${caption_host}:/etc/systemd/system/rbv-caption@.service \
&& ssh ${caption_host} sudo sed -i -e "s/password/${postgres_password}/" /etc/systemd/system/rbv-caption@.service \
&& ssh ${caption_host} sudo systemctl daemon-reload \
&& ssh ${caption_host} sudo systemctl start --no-block rbv-caption@florence-2-base.service \
&& ssh ${caption_host} sudo systemctl start --no-block rbv-caption@florence-2-base-ocr.service \
&& ssh ${caption_host} sudo systemctl start --no-block rbv-caption@vit-gpt2.service; then
echo "rbv caption deployed successfully"
else
echo "failed to deploy rbv caption"
exit 1
fi
}
deploy_ui() {
if ssh ${ui_host} sudo step certificate verify \
/etc/nginx/tls/rbv/rbv.pem \
@@ -202,15 +245,16 @@ components=("${@}")
if [ ${#components[@]} -eq 0 ]; then
components=(api ui)
elif [ "${components[0]}" = "all" ]; then
components=(db index api ui)
components=(db index api ui caption)
fi
for component in "${components[@]}"; do
case ${component} in
db) deploy_db ;;
index) deploy_index ;;
api) deploy_api ;;
ui) deploy_ui ;;
*) echo "unknown component: ${component}"; exit 1 ;;
db) deploy_db ;;
index) deploy_index ;;
api) deploy_api ;;
ui) deploy_ui ;;
caption) deploy_caption ;;
*) echo "unknown component: ${component}"; exit 1 ;;
esac
done

View File

@@ -66,6 +66,9 @@ export const getGalleryImages = (id: string) => request<GalleryImage[]>(`/galler
export const getGalleryPersons = (id: string) => request<Person[]>(`/galleries/${id}/persons`)
export const getGalleryCaptions = (id: string, model = 'florence-2-base') =>
request<Record<string, string>>(`/galleries/${id}/captions?model=${encodeURIComponent(model)}`)
export const galleriesByTag = (tag: string, page = 1, perPage = 24) =>
request<Gallery[]>(`/galleries/by-tag?tag=${encodeURIComponent(tag)}&page=${page}&per_page=${perPage}`)
@@ -220,6 +223,8 @@ export const search = (params: {
tag?: string
subject?: string
person_name?: string
caption?: string
label?: string
limit?: number
}) =>
request<SearchResult[]>('/search', { method: 'POST', body: JSON.stringify(params) })

View File

@@ -82,6 +82,10 @@
"subjectPlaceholder": "beach, concert...",
"personNameLabel": "Filter by person",
"personNamePlaceholder": "jane, john-doe...",
"captionLabel": "Filter by image caption",
"captionPlaceholder": "birthday cake, group photo...",
"labelLabel": "Filter by image text",
"labelPlaceholder": "watermark, sign text...",
"quickPlaceholder": "Search galleries...",
"quickResults": "Results for \"{{query}}\"",
"submit": "Search",

View File

@@ -2,7 +2,7 @@ import { useEffect, useState, useCallback, useRef } from 'react'
import { useParams, Link } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import {
getGallery, getGalleryImages, getGalleryPersons,
getGallery, getGalleryImages, getGalleryPersons, getGalleryCaptions,
imageFileUrl, thumbnailUrl,
listFavouriteIds, addFavourite, removeFavourite,
addGalleryTag, listTags,
@@ -28,6 +28,7 @@ export function Gallery() {
const [subjectInput, setSubjectInput] = useState('')
const [allSubjects, setAllSubjects] = useState<string[]>([])
const [showSubjectInput, setShowSubjectInput] = useState(false)
const [captions, setCaptions] = useState<Record<string, string>>({})
const [zoom, setZoom] = useState(1)
const [offset, setOffset] = useState({ x: 0, y: 0 })
const [dragging, setDragging] = useState(false)
@@ -45,6 +46,7 @@ export function Gallery() {
listTags().then(setAllTags).catch(() => setAllTags([]))
listSubjects().then(setAllSubjects).catch(() => setAllSubjects([]))
getConfig().then(setConfig).catch(() => setConfig(null))
getGalleryCaptions(id).then(setCaptions).catch(() => setCaptions({}))
}, [id])
// Reset zoom/pan and scroll active thumbnail into view
@@ -233,7 +235,7 @@ export function Gallery() {
key={currentImage.image_id}
className="gallery-main-img"
src={imgUrl(currentImage)}
alt={currentImage.filename}
alt={captions[currentImage.image_id] || currentImage.filename}
style={{ transform: `translate(${offset.x}px, ${offset.y}px) scale(${zoom})` }}
draggable={false}
/>
@@ -247,7 +249,7 @@ export function Gallery() {
<div className="thumb-strip" ref={thumbStripRef}>
{images.map((img, i) => (
<button key={img.image_id} className={`thumb-btn${i === current ? ' active' : ''}`} onClick={() => setCurrent(i)}>
<img src={thumbUrl(img)} alt={img.filename} loading="lazy" />
<img src={thumbUrl(img)} alt={captions[img.image_id] || img.filename} loading="lazy" />
</button>
))}
</div>

View File

@@ -12,6 +12,8 @@ export function Search() {
const [tag, setTag] = useState('')
const [subject, setSubject] = useState('')
const [personName, setPersonName] = useState('')
const [caption, setCaption] = useState('')
const [label, setLabel] = useState('')
const [results, setResults] = useState<SearchResult[]>([])
const [thumbs, setThumbs] = useState<Record<string, GalleryImage | null>>({})
const [loading, setLoading] = useState(false)
@@ -61,11 +63,13 @@ export function Search() {
tag: tag.trim() || undefined,
subject: subject.trim() || undefined,
person_name: personName.trim() || undefined,
caption: caption.trim() || undefined,
label: label.trim() || undefined,
limit: 24,
})
}
const hasInput = text.trim() || tag.trim() || subject.trim() || personName.trim()
const hasInput = text.trim() || tag.trim() || subject.trim() || personName.trim() || caption.trim() || label.trim()
return (
<div className="page">
@@ -108,6 +112,22 @@ export function Search() {
placeholder={t('search.personNamePlaceholder')}
/>
<label className="field-label">{t('search.captionLabel')}</label>
<input
type="text"
value={caption}
onChange={e => setCaption(e.target.value)}
placeholder={t('search.captionPlaceholder')}
/>
<label className="field-label">{t('search.labelLabel')}</label>
<input
type="text"
value={label}
onChange={e => setLabel(e.target.value)}
placeholder={t('search.labelPlaceholder')}
/>
<button className="btn" type="submit" disabled={loading || !hasInput}>
{t('search.submit')}
</button>