From cb303832bcd663f7adc6e22e85b6e0971c01ddd0 Mon Sep 17 00:00:00 2001 From: rob thijssen Date: Sun, 31 May 2026 23:43:11 +0300 Subject: [PATCH] feat(neuron): render the model's chat_template with chat_template_kwargs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #9. Replaces the hardcoded `format_qwen3_prompt` ChatML glue with `minijinja`-driven rendering of the model's own `chat_template` from `tokenizer_config.json`. The request's `chat_template_kwargs` flow into the Jinja context so model-specific levers (Qwen3's `enable_thinking: false`, etc.) actually take effect. ## Implementation - New `harness::chat_template` module with three entry points: - `load_chat_template_alongside(tokenizer_json_path)` — probes `tokenizer_config.json` in the same hf-hub snapshot directory. Supports both the canonical string-form `chat_template` and the array-form some tokenizers ship (multi-template models). - `render_chat_template(template, messages, tools, kwargs)` — renders via `minijinja`. Messages flatten into the `[{role, content}]` shape HF templates iterate, with per-message extras (`tool_calls`, `tool_call_id`) preserved. `tools` and `kwargs` are added to the Jinja context so templates that reference them work without us interpreting their shape. - `chat_templates_enabled()` reads `NEURON_USE_CHAT_TEMPLATE` (default true). Falsy values force the fallback path everywhere — a kill switch for emergency rollback without a rebuild. - `LoadedModel.chat_template: Option<String>` and the TP equivalent are populated once at load time. `None` (no tokenizer_config.json, parse error, missing field) silently routes to the fallback path; logs go through `tracing::debug`/`warn` per condition. - New `build_prompt_for_request(chat_template, request)` wraps the decision: when both the template is present AND the kill switch is off, render with kwargs from `request.extra` (looks up `chat_template_kwargs` and `tools` lazily).
On render error → warn + fallback to `format_qwen3_prompt`. Wired into all four current prompt-build sites (single-GPU stream + non-stream, TP stream + non-stream). ## Dependency `minijinja = "2"` with the `builtins`, `json`, and `serde` features. Pure-Rust Jinja2 implementation, ~80KB compiled. Used internally by HF's `tokenizers-rs` for its own chat templating; the API surface we touch (`Environment::add_template` + `Template::render(serde_value)`) is stable. ## Validation strategy I can't byte-compare the new path's output against `format_qwen3_prompt` for live models without GPU (CI doesn't have one). The fallback path and kill switch are the mitigations — a deploy can flip `NEURON_USE_CHAT_TEMPLATE=false` in the neuron service env if the chat template renders surprisingly on Qwen3-8B in production. The legacy formatter stays the fail-closed default. ## Scope cuts (documented in module header) - Tool-definition lifting from helexa-acp's system-prompt injection into the chat_template's native tools block is deferred. Today the request's `tools` array threads into the Jinja context, but helexa-acp continues to inject Hermes-format tool descriptions into the system prompt for backwards-compat with non-cortex endpoints. ## Tests 9 unit tests in `chat_template`: kill-switch matrix (truthy / falsy / unset), template loading (string form, array form, missing file, unparseable JSON, missing field), rendering (basic conversation threading, kwargs forwarding, message-extras threading for tool_calls). 215 workspace tests pass; clippy + fmt clean across all workspace features (default). 
Co-Authored-By: Claude Opus 4.7 --- Cargo.lock | 18 + crates/neuron/Cargo.toml | 9 + crates/neuron/src/harness/candle.rs | 101 +++++- crates/neuron/src/harness/chat_template.rs | 392 +++++++++++++++++++++ crates/neuron/src/harness/mod.rs | 1 + 5 files changed, 517 insertions(+), 4 deletions(-) create mode 100644 crates/neuron/src/harness/chat_template.rs diff --git a/Cargo.lock b/Cargo.lock index c3f30b1..f8bc7b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2379,6 +2379,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + [[package]] name = "metrics" version = "0.24.3" @@ -2432,6 +2438,17 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minijinja" +version = "2.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2929e494b2280e1e18959bb2e121da03347ae896896fdfaceaab43c88a02803f" +dependencies = [ + "memo-map", + "serde", + "serde_json", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2516,6 +2533,7 @@ dependencies = [ "futures", "half", "hf-hub", + "minijinja", "reqwest", "safetensors 0.7.0", "serde", diff --git a/crates/neuron/Cargo.toml b/crates/neuron/Cargo.toml index 7cf7cf4..80711aa 100644 --- a/crates/neuron/Cargo.toml +++ b/crates/neuron/Cargo.toml @@ -76,6 +76,15 @@ cudarc = { version = "0.19", optional = true, default-features = false, features half = { version = "2.5", optional = true } tokenizers = { version = "0.22", default-features = false, features = ["onig"] } hf-hub = { version = "0.4", features = ["tokio"] } +# Jinja-compatible template renderer for the model's +# `tokenizer_config.json::chat_template`. 
Hugging Face's chat +# templates use a strict subset of Jinja2 that minijinja supports +# out of the box. ~80KB compiled; pure Rust, no async surface. +# Features: `builtins` for the `is defined` / `default` filters HF +# templates use; `json` for `tojson` (some Qwen3 templates emit +# tool definitions via tojson); `serde` so we can hand it a +# serde_json::Value as the context. +minijinja = { version = "2", features = ["builtins", "json", "serde"] } # Direct dep on `safetensors` (re-exported by candle but its `TensorView` # / `slice::IndexOp` types are public-but-not-re-exported). Used by the # tp `fused_load` module to read per-rank slices of fused QKV tensors diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 404a27c..c5f75bf 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -167,6 +167,15 @@ pub struct LoadedModel { /// through as plain text in that case and the consumer parses /// the markers itself if it knows how. pub tool_call_tokens: Option, + /// Raw Jinja `chat_template` string loaded from this model's + /// `tokenizer_config.json` at load time. `None` when the file + /// is absent / unparseable / lacks the field. When `Some`, + /// the prompt-build path renders it through `minijinja` with + /// `chat_template_kwargs` from the request body; when `None`, + /// the hardcoded Qwen3 ChatML fallback (`format_qwen3_prompt`) + /// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var + /// forces the fallback path even when `Some`. + pub chat_template: Option, } impl LoadedModel { @@ -229,6 +238,8 @@ pub struct TpLoadedModel { pub reasoning_tokens: Option, /// Same shape as [`LoadedModel::tool_call_tokens`]. pub tool_call_tokens: Option, + /// Same shape as [`LoadedModel::chat_template`]. 
+ pub chat_template: Option, } #[cfg(feature = "cuda")] @@ -1397,7 +1408,7 @@ impl CandleHarness { let _inference_guard = loaded.inference_lock.lock().await; let result = async { - let prompt = format_qwen3_prompt(&request.messages); + let prompt = build_prompt_for_request(loaded.chat_template.as_deref(), &request); let encoding = loaded .tokenizer @@ -1702,7 +1713,7 @@ impl CandleHarness { } }; - let prompt = format_qwen3_prompt(&request.messages); + let prompt = build_prompt_for_request(loaded.chat_template.as_deref(), &request); let encoding = loaded .tokenizer .encode(prompt.as_str(), true) @@ -2081,6 +2092,19 @@ impl Harness for CandleHarness { "tool-call markers detected — streaming will emit structured ToolCall events" ); } + // Probe `tokenizer_config.json` in the same snapshot dir. + // When present and non-empty, the inference path renders + // this Jinja template with the request's + // `chat_template_kwargs` instead of using the hardcoded + // ChatML formatter. Best-effort: missing or unparseable + // configs silently fall through to the legacy path. 
+ let chat_template = super::chat_template::load_chat_template_alongside(&tokenizer_path); + if chat_template.is_some() { + tracing::info!( + model = %spec.model_id, + "chat_template loaded from tokenizer_config.json — prompt assembly will use the model's own template" + ); + } let loaded = Arc::new(LoadedModel { model_id: spec.model_id.clone(), @@ -2095,6 +2119,7 @@ impl Harness for CandleHarness { inference_lock: tokio::sync::Mutex::new(()), reasoning_tokens, tool_call_tokens, + chat_template, }); let mut models = self.models.write().await; @@ -2288,6 +2313,13 @@ impl CandleHarness { "TP load: tool-call markers detected" ); } + let chat_template = super::chat_template::load_chat_template_alongside(&tokenizer_path); + if chat_template.is_some() { + tracing::info!( + model = %spec.model_id, + "TP load: chat_template loaded from tokenizer_config.json" + ); + } let tp_loaded = StdArc::new(TpLoadedModel { model_id: spec.model_id.clone(), @@ -2303,6 +2335,7 @@ impl CandleHarness { worker: leader_worker, reasoning_tokens, tool_call_tokens, + chat_template, }); let mut models = self.models.write().await; @@ -2429,7 +2462,7 @@ impl CandleHarness { return Err(poisoned_error(&request.model)); } - let prompt = format_qwen3_prompt(&request.messages); + let prompt = build_prompt_for_request(tp.chat_template.as_deref(), &request); let encoding = tp .tokenizer .encode(prompt.as_str(), true) @@ -2893,7 +2926,7 @@ async fn chat_completion_tp_inner( let req_start = std::time::Instant::now(); let model_id = request.model.clone(); - let prompt = format_qwen3_prompt(&request.messages); + let prompt = build_prompt_for_request(tp.chat_template.as_deref(), &request); let encoding = tp .tokenizer .encode(prompt.as_str(), true) @@ -3242,6 +3275,66 @@ pub enum InferenceError { Other(#[from] anyhow::Error), } +/// Build the model's prompt from a [`ChatCompletionRequest`]. 
+/// +/// Prefers the model's own `chat_template` when one was loaded +/// from `tokenizer_config.json` at startup and the +/// `NEURON_USE_CHAT_TEMPLATE` kill switch isn't tripped. The +/// request's `chat_template_kwargs` (e.g. +/// `{"enable_thinking": false}` on Qwen3) and `tools` array flow +/// into the template's Jinja context so model-specific behaviour +/// like reasoning-suppression-at-generation works. +/// +/// Falls back to [`format_qwen3_prompt`] (the legacy hardcoded +/// ChatML glue) on any of: +/// +/// - no `chat_template` loaded for this model (older quantised +/// variants, fallback-only models) +/// - the env kill switch is set to a falsy value +/// - the template rendered to an error (caller can flip the env +/// var to force fallback while debugging the template) +/// +/// Failures are logged at `warn` so an operator running with +/// `RUST_LOG=neuron=debug` sees which path each request took. +fn build_prompt_for_request( + chat_template: Option<&str>, + request: &ChatCompletionRequest, +) -> String { + if !super::chat_template::chat_templates_enabled() { + return format_qwen3_prompt(&request.messages); + } + let Some(tmpl) = chat_template else { + return format_qwen3_prompt(&request.messages); + }; + + // Pull `chat_template_kwargs` and `tools` from the request's + // catch-all `extra` field. Both are optional; absent fields + // become `Value::Null`, which the renderer skips inserting + // into the Jinja context. 
+ let kwargs = request + .extra + .get("chat_template_kwargs") + .cloned() + .unwrap_or(serde_json::Value::Null); + let tools = request + .extra + .get("tools") + .cloned() + .unwrap_or(serde_json::Value::Null); + + match super::chat_template::render_chat_template(tmpl, &request.messages, &tools, &kwargs) { + Ok(prompt) => prompt, + Err(e) => { + tracing::warn!( + model = %request.model, + error = %format!("{e:#}"), + "chat_template render failed; falling back to format_qwen3_prompt" + ); + format_qwen3_prompt(&request.messages) + } + } +} + /// Apply the Qwen3 chat template: /// /// ```text diff --git a/crates/neuron/src/harness/chat_template.rs b/crates/neuron/src/harness/chat_template.rs new file mode 100644 index 0000000..b25cc44 --- /dev/null +++ b/crates/neuron/src/harness/chat_template.rs @@ -0,0 +1,392 @@ +//! Chat-template rendering for the model-supplied Jinja templates +//! HuggingFace tokenizers ship in `tokenizer_config.json`. +//! +//! ## Background +//! +//! Every modern open-weight model bundles a `chat_template` field +//! in its `tokenizer_config.json` — a Jinja2 template string that +//! converts a sequence of `{role, content}` messages into the +//! exact prompt the model was trained on. Examples: +//! +//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…` +//! with conditional `enable_thinking` handling that injects an +//! empty `<think>\n\n</think>` block when set false. +//! - DeepSeek-R1: similar im_start framing with different special- +//! token names. +//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing. +//! - Claude / Llama: another shape again. +//! +//! Rendering the model's own template is the only way to get the +//! *exact* prompt format the model was trained on plus the +//! model-specific kwargs (`enable_thinking`, `tools`, …) without +//! hardcoding per-model logic. The alternative — neuron's previous +//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that +//! ignored kwargs entirely. +//! +//! ## Scope +//! +//!
This module is request-side only: it builds the prompt string +//! the tokenizer ingests before inference. The reasoning- and +//! tool-call-marker token routing (issues #6, #8) is response-side +//! and stays in `wire::openai_chat` / the streaming inference +//! loops. +//! +//! ## Fallback +//! +//! When the model's `tokenizer_config.json` is missing, doesn't +//! parse, lacks a `chat_template`, or fails to render, the caller +//! falls back to `format_qwen3_prompt`. The +//! `NEURON_USE_CHAT_TEMPLATE=false` env var is a global kill +//! switch — if a deploy goes sideways and the renderer is to +//! blame, an operator can flip the env and restart neuron without +//! shipping a new build. + +use anyhow::{Context, Result}; +use cortex_core::openai::{ChatMessage, MessageContent}; +use minijinja::Environment; +use serde_json::Value; +use std::path::Path; + +/// Environment variable that, when set to `false`/`0`/`no`/`off`, +/// forces every model to skip its `chat_template` and fall back +/// to `format_qwen3_prompt`. Default (unset) is "use chat +/// templates where available". +pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE"; + +/// Read the global kill switch. `true` means chat templates are +/// enabled; `false` forces the fallback path everywhere. +pub fn chat_templates_enabled() -> bool { + match std::env::var(KILL_SWITCH_ENV).ok().as_deref() { + Some(s) => !matches!( + s.trim().to_ascii_lowercase().as_str(), + "false" | "0" | "no" | "off" + ), + None => true, + } +} + +/// Convenience: probe for `tokenizer_config.json` in the same +/// directory the tokenizer was loaded from. Both files come from +/// the same HuggingFace snapshot in the hf-hub cache, so the +/// sibling path is reliable.
+pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option { + let parent = tokenizer_json_path.parent()?; + let config_path = parent.join("tokenizer_config.json"); + load_chat_template_from(&config_path) +} + +/// Best-effort load of `chat_template` from a HuggingFace +/// `tokenizer_config.json`. Returns `None` when the file is +/// absent, doesn't parse, or lacks the `chat_template` field — +/// in all of those cases the caller falls back to +/// `format_qwen3_prompt`. Warnings are logged so an operator can +/// see why the fallback fired. +pub fn load_chat_template_from(path: &Path) -> Option { + let text = match std::fs::read_to_string(path) { + Ok(t) => t, + Err(e) => { + tracing::debug!( + path = %path.display(), + error = %e, + "chat_template: tokenizer_config.json absent or unreadable; falling back" + ); + return None; + } + }; + let value: Value = match serde_json::from_str(&text) { + Ok(v) => v, + Err(e) => { + tracing::warn!( + path = %path.display(), + error = %e, + "chat_template: tokenizer_config.json failed to parse; falling back" + ); + return None; + } + }; + // Some tokenizer_config.json files carry `chat_template` as an + // array of `{name, template}` objects (multi-template models — + // tool-use variant, default variant). For now we pick the first + // entry; future iterations could honour a name hint. + match value.get("chat_template") { + Some(Value::String(s)) => Some(s.clone()), + Some(Value::Array(arr)) => { + for entry in arr { + if let Some(t) = entry.get("template").and_then(|v| v.as_str()) { + return Some(t.to_string()); + } + } + tracing::warn!( + path = %path.display(), + "chat_template: array form had no usable template entry; falling back" + ); + None + } + _ => None, + } +} + +/// Render the chat template into the prompt the model expects. +/// +/// `template` is the raw Jinja string from `tokenizer_config.json`. +/// `messages` is the conversation in order. 
`kwargs` is the +/// `chat_template_kwargs` object the client supplied on the +/// request (or `Value::Null` when absent). The function expands +/// the kwargs into the Jinja context alongside the standard +/// `messages` and `add_generation_prompt` variables HF templates +/// expect. +/// +/// `tools` is the request's `tools` array (or `Value::Null`). +/// Some chat templates iterate it to emit native tool definitions +/// (Qwen3-Coder's tool-use template, Mistral's [TOOL_DEFINITIONS] +/// frame). We forward whatever the client sent without +/// interpretation. +pub fn render_chat_template( + template: &str, + messages: &[ChatMessage], + tools: &Value, + kwargs: &Value, +) -> Result { + let mut env = Environment::new(); + // Compile the template against a fixed name so error messages + // surface "chat_template" rather than `