feat(neuron): render the model's chat_template with chat_template_kwargs
Some checks failed
CI / CUDA type-check (push) Failing after 58s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Format (push) Successful in 40s
build-prerelease / Build neuron-ampere (push) Failing after 1s
CI / Clippy (push) Successful in 2m37s
build-prerelease / Build cortex binary (push) Successful in 4m47s
CI / Test (push) Successful in 6m13s
build-prerelease / Build neuron-blackwell (push) Failing after 5m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ada (push) Failing after 7m20s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Some checks failed
CI / CUDA type-check (push) Failing after 58s
build-prerelease / Resolve version stamps (push) Successful in 39s
CI / Format (push) Successful in 40s
build-prerelease / Build neuron-ampere (push) Failing after 1s
CI / Clippy (push) Successful in 2m37s
build-prerelease / Build cortex binary (push) Successful in 4m47s
CI / Test (push) Successful in 6m13s
build-prerelease / Build neuron-blackwell (push) Failing after 5m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Package cortex RPM (push) Successful in 1m27s
build-prerelease / Build neuron-ada (push) Failing after 7m20s
build-prerelease / Package helexa-neuron-ada RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-ampere RPM (push) Has been skipped
build-prerelease / Package helexa-neuron-blackwell RPM (push) Has been skipped
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Has been skipped
Closes #9. Replaces the hardcoded `format_qwen3_prompt` ChatML glue with `minijinja`-driven rendering of the model's own `chat_template` from `tokenizer_config.json`. The request's `chat_template_kwargs` flow into the Jinja context so model-specific levers (Qwen3's `enable_thinking: false`, etc.) actually take effect. ## Implementation - New `harness::chat_template` module with three entry points: - `load_chat_template_alongside(tokenizer_json_path)` — probes `tokenizer_config.json` in the same hf-hub snapshot directory. Supports both the canonical string-form `chat_template` and the array-form some tokenizers ship (multi-template models). - `render_chat_template(template, messages, tools, kwargs)` — renders via `minijinja`. Messages flatten into the `[{role, content}]` shape HF templates iterate, with per-message extras (`tool_calls`, `tool_call_id`) preserved. `tools` and `kwargs` add into the Jinja context so templates that reference them work without us interpreting their shape. - `chat_templates_enabled()` reads `NEURON_USE_CHAT_TEMPLATE` (default true). Falsy values force the fallback path everywhere — a kill switch for emergency rollback without a rebuild. - `LoadedModel.chat_template: Option<String>` and the TP equivalent are populated once at load time. `None` (no tokenizer_config.json, parse error, missing field) routes the fallback path silently; logs go through `tracing::debug`/`warn` per condition. - New `build_prompt_for_request(chat_template, request)` wraps the decision: when both the template is present AND the kill switch is off, render with kwargs from `request.extra` (looks up `chat_template_kwargs` and `tools` lazily). On render error → warn + fallback to `format_qwen3_prompt`. Wired into all four current prompt-build sites (single-GPU stream + non-stream, TP stream + non-stream). ## Dependency `minijinja = "2"` with the `builtins`, `json`, and `serde` features. Pure-Rust Jinja2 implementation, ~80KB compiled. Used internally by HF's `tokenizers-rs` for its own chat templating; the API surface we touch (`Environment::add_template` + `Template::render(serde_value)`) is stable. ## Validation strategy I can't byte-compare the new path's output against `format_qwen3_prompt` for live models without GPU (CI doesn't have one). The fallback path and kill switch are the mitigations — a deploy can flip `NEURON_USE_CHAT_TEMPLATE=false` in the neuron service env if the chat template renders surprisingly on Qwen3-8B in production. The legacy formatter stays the fail-closed default. ## Scope cuts (documented in module header) - Tool-definition lifting from helexa-acp's system-prompt injection into the chat_template's native tools block is deferred. Today the request's `tools` array threads into the Jinja context, but helexa-acp continues to inject Hermes-format tool descriptions into the system prompt for backwards-compat with non-cortex endpoints. ## Tests 9 unit tests in `chat_template`: kill-switch matrix (truthy / falsy / unset), template loading (string form, array form, missing file, unparseable JSON, missing field), rendering (basic conversation threading, kwargs forwarding, message-extras threading for tool_calls). 215 workspace tests pass; clippy + fmt clean across all workspace features (default). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -76,6 +76,15 @@ cudarc = { version = "0.19", optional = true, default-features = false, features
|
||||
half = { version = "2.5", optional = true }
|
||||
tokenizers = { version = "0.22", default-features = false, features = ["onig"] }
|
||||
hf-hub = { version = "0.4", features = ["tokio"] }
|
||||
# Jinja-compatible template renderer for the model's
|
||||
# `tokenizer_config.json::chat_template`. Hugging Face's chat
|
||||
# templates use a strict subset of Jinja2 that minijinja supports
|
||||
# out of the box. ~80KB compiled; pure Rust, no async surface.
|
||||
# Features: `builtins` for the `is defined` / `default` filters HF
|
||||
# templates use; `json` for `tojson` (some Qwen3 templates emit
|
||||
# tool definitions via tojson); `serde` so we can hand it a
|
||||
# serde_json::Value as the context.
|
||||
minijinja = { version = "2", features = ["builtins", "json", "serde"] }
|
||||
# Direct dep on `safetensors` (re-exported by candle but its `TensorView`
|
||||
# / `slice::IndexOp` types are public-but-not-re-exported). Used by the
|
||||
# tp `fused_load` module to read per-rank slices of fused QKV tensors
|
||||
|
||||
@@ -167,6 +167,15 @@ pub struct LoadedModel {
|
||||
/// through as plain text in that case and the consumer parses
|
||||
/// the markers itself if it knows how.
|
||||
pub tool_call_tokens: Option<ToolCallTokenPair>,
|
||||
/// Raw Jinja `chat_template` string loaded from this model's
|
||||
/// `tokenizer_config.json` at load time. `None` when the file
|
||||
/// is absent / unparseable / lacks the field. When `Some`,
|
||||
/// the prompt-build path renders it through `minijinja` with
|
||||
/// `chat_template_kwargs` from the request body; when `None`,
|
||||
/// the hardcoded Qwen3 ChatML fallback (`format_qwen3_prompt`)
|
||||
/// is used. The `NEURON_USE_CHAT_TEMPLATE=false` env var
|
||||
/// forces the fallback path even when `Some`.
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
impl LoadedModel {
|
||||
@@ -229,6 +238,8 @@ pub struct TpLoadedModel {
|
||||
pub reasoning_tokens: Option<ReasoningTokenPair>,
|
||||
/// Same shape as [`LoadedModel::tool_call_tokens`].
|
||||
pub tool_call_tokens: Option<ToolCallTokenPair>,
|
||||
/// Same shape as [`LoadedModel::chat_template`].
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
@@ -1397,7 +1408,7 @@ impl CandleHarness {
|
||||
let _inference_guard = loaded.inference_lock.lock().await;
|
||||
|
||||
let result = async {
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let prompt = build_prompt_for_request(loaded.chat_template.as_deref(), &request);
|
||||
|
||||
let encoding = loaded
|
||||
.tokenizer
|
||||
@@ -1702,7 +1713,7 @@ impl CandleHarness {
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let prompt = build_prompt_for_request(loaded.chat_template.as_deref(), &request);
|
||||
let encoding = loaded
|
||||
.tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
@@ -2081,6 +2092,19 @@ impl Harness for CandleHarness {
|
||||
"tool-call markers detected — streaming will emit structured ToolCall events"
|
||||
);
|
||||
}
|
||||
// Probe `tokenizer_config.json` in the same snapshot dir.
|
||||
// When present and non-empty, the inference path renders
|
||||
// this Jinja template with the request's
|
||||
// `chat_template_kwargs` instead of using the hardcoded
|
||||
// ChatML formatter. Best-effort: missing or unparseable
|
||||
// configs silently fall through to the legacy path.
|
||||
let chat_template = super::chat_template::load_chat_template_alongside(&tokenizer_path);
|
||||
if chat_template.is_some() {
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
"chat_template loaded from tokenizer_config.json — prompt assembly will use the model's own template"
|
||||
);
|
||||
}
|
||||
|
||||
let loaded = Arc::new(LoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
@@ -2095,6 +2119,7 @@ impl Harness for CandleHarness {
|
||||
inference_lock: tokio::sync::Mutex::new(()),
|
||||
reasoning_tokens,
|
||||
tool_call_tokens,
|
||||
chat_template,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -2288,6 +2313,13 @@ impl CandleHarness {
|
||||
"TP load: tool-call markers detected"
|
||||
);
|
||||
}
|
||||
let chat_template = super::chat_template::load_chat_template_alongside(&tokenizer_path);
|
||||
if chat_template.is_some() {
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
"TP load: chat_template loaded from tokenizer_config.json"
|
||||
);
|
||||
}
|
||||
|
||||
let tp_loaded = StdArc::new(TpLoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
@@ -2303,6 +2335,7 @@ impl CandleHarness {
|
||||
worker: leader_worker,
|
||||
reasoning_tokens,
|
||||
tool_call_tokens,
|
||||
chat_template,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
@@ -2429,7 +2462,7 @@ impl CandleHarness {
|
||||
return Err(poisoned_error(&request.model));
|
||||
}
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let prompt = build_prompt_for_request(tp.chat_template.as_deref(), &request);
|
||||
let encoding = tp
|
||||
.tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
@@ -2893,7 +2926,7 @@ async fn chat_completion_tp_inner(
|
||||
let req_start = std::time::Instant::now();
|
||||
let model_id = request.model.clone();
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let prompt = build_prompt_for_request(tp.chat_template.as_deref(), &request);
|
||||
let encoding = tp
|
||||
.tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
@@ -3242,6 +3275,66 @@ pub enum InferenceError {
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Build the model's prompt from a [`ChatCompletionRequest`].
|
||||
///
|
||||
/// Prefers the model's own `chat_template` when one was loaded
|
||||
/// from `tokenizer_config.json` at startup and the
|
||||
/// `NEURON_USE_CHAT_TEMPLATE` kill switch isn't tripped. The
|
||||
/// request's `chat_template_kwargs` (e.g.
|
||||
/// `{"enable_thinking": false}` on Qwen3) and `tools` array flow
|
||||
/// into the template's Jinja context so model-specific behaviour
|
||||
/// like reasoning-suppression-at-generation works.
|
||||
///
|
||||
/// Falls back to [`format_qwen3_prompt`] (the legacy hardcoded
|
||||
/// ChatML glue) on any of:
|
||||
///
|
||||
/// - no `chat_template` loaded for this model (older quantised
|
||||
/// variants, fallback-only models)
|
||||
/// - the env kill switch is set to a falsy value
|
||||
/// - the template rendered to an error (caller can flip the env
|
||||
/// var to force fallback while debugging the template)
|
||||
///
|
||||
/// Failures are logged at `warn` so an operator running with
|
||||
/// `RUST_LOG=neuron=debug` sees which path each request took.
|
||||
fn build_prompt_for_request(
|
||||
chat_template: Option<&str>,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> String {
|
||||
if !super::chat_template::chat_templates_enabled() {
|
||||
return format_qwen3_prompt(&request.messages);
|
||||
}
|
||||
let Some(tmpl) = chat_template else {
|
||||
return format_qwen3_prompt(&request.messages);
|
||||
};
|
||||
|
||||
// Pull `chat_template_kwargs` and `tools` from the request's
|
||||
// catch-all `extra` field. Both are optional; absent fields
|
||||
// become `Value::Null`, which the renderer skips inserting
|
||||
// into the Jinja context.
|
||||
let kwargs = request
|
||||
.extra
|
||||
.get("chat_template_kwargs")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
let tools = request
|
||||
.extra
|
||||
.get("tools")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
match super::chat_template::render_chat_template(tmpl, &request.messages, &tools, &kwargs) {
|
||||
Ok(prompt) => prompt,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
model = %request.model,
|
||||
error = %format!("{e:#}"),
|
||||
"chat_template render failed; falling back to format_qwen3_prompt"
|
||||
);
|
||||
format_qwen3_prompt(&request.messages)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply the Qwen3 chat template:
|
||||
///
|
||||
/// ```text
|
||||
|
||||
392
crates/neuron/src/harness/chat_template.rs
Normal file
392
crates/neuron/src/harness/chat_template.rs
Normal file
@@ -0,0 +1,392 @@
|
||||
//! Chat-template rendering for the model-supplied Jinja templates
|
||||
//! HuggingFace tokenizers ship in `tokenizer_config.json`.
|
||||
//!
|
||||
//! ## Background
|
||||
//!
|
||||
//! Every modern open-weight model bundles a `chat_template` field
|
||||
//! in its `tokenizer_config.json` — a Jinja2 template string that
|
||||
//! converts a sequence of `{role, content}` messages into the
|
||||
//! exact prompt the model was trained on. Examples:
|
||||
//!
|
||||
//! - Qwen3-Coder: `<|im_start|>{role}\n{content}<|im_end|>\n…`
|
||||
//! with conditional `enable_thinking` handling that injects an
|
||||
//! empty `<think>\n\n</think>` block when set false.
|
||||
//! - DeepSeek-R1: similar im_start framing with different special-
|
||||
//! token names.
|
||||
//! - Mistral / Magistral: a `[INST]` / `[/INST]` framing.
|
||||
//! - Claude / Llama: another shape again.
|
||||
//!
|
||||
//! Rendering the model's own template is the only way to get the
|
||||
//! *exact* prompt format the model was trained on plus the
|
||||
//! model-specific kwargs (`enable_thinking`, `tools`, …) without
|
||||
//! hardcoding per-model logic. The alternative — neuron's previous
|
||||
//! `format_qwen3_prompt` — was a hardcoded Qwen3 ChatML glue that
|
||||
//! ignored kwargs entirely.
|
||||
//!
|
||||
//! ## Scope
|
||||
//!
|
||||
//! This module is request-side only: it builds the prompt string
|
||||
//! the tokenizer ingests before inference. The reasoning- and
|
||||
//! tool-call-marker token routing (issues #6, #8) is response-side
|
||||
//! and stays in `wire::openai_chat` / the streaming inference
|
||||
//! loops.
|
||||
//!
|
||||
//! ## Fallback
|
||||
//!
|
||||
//! When the model's `tokenizer_config.json` is missing, doesn't
|
||||
//! parse, lacks a `chat_template`, or renders an error, the caller
|
||||
//! falls back to `format_qwen3_prompt`. The
|
||||
//! `NEURON_USE_CHAT_TEMPLATE=false` env var is a global kill
|
||||
//! switch — if a deploy goes sideways and the renderer is to
|
||||
//! blame, an operator can flip the env and restart neuron without
|
||||
//! shipping a new build.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use cortex_core::openai::{ChatMessage, MessageContent};
|
||||
use minijinja::Environment;
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
|
||||
/// Environment variable that, when set to `false`/`0`/`no`,
|
||||
/// forces every model to skip its `chat_template` and fall back
|
||||
/// to `format_qwen3_prompt`. Default (unset) is "use chat
|
||||
/// templates where available".
|
||||
pub const KILL_SWITCH_ENV: &str = "NEURON_USE_CHAT_TEMPLATE";
|
||||
|
||||
/// Read the global kill switch. `true` means chat templates are
|
||||
/// enabled; `false` forces the fallback path everywhere.
|
||||
pub fn chat_templates_enabled() -> bool {
|
||||
match std::env::var(KILL_SWITCH_ENV).ok().as_deref() {
|
||||
Some(s) => !matches!(
|
||||
s.trim().to_ascii_lowercase().as_str(),
|
||||
"false" | "0" | "no" | "off"
|
||||
),
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience: probe for `tokenizer_config.json` in the same
|
||||
/// directory the tokenizer was loaded from. Both files come from
|
||||
/// the same HuggingFace snapshot in the hf-hub cache, so the
|
||||
/// sibling path is reliable.
|
||||
pub fn load_chat_template_alongside(tokenizer_json_path: &Path) -> Option<String> {
|
||||
let parent = tokenizer_json_path.parent()?;
|
||||
let config_path = parent.join("tokenizer_config.json");
|
||||
load_chat_template_from(&config_path)
|
||||
}
|
||||
|
||||
/// Best-effort load of `chat_template` from a HuggingFace
|
||||
/// `tokenizer_config.json`. Returns `None` when the file is
|
||||
/// absent, doesn't parse, or lacks the `chat_template` field —
|
||||
/// in all of those cases the caller falls back to
|
||||
/// `format_qwen3_prompt`. Warnings are logged so an operator can
|
||||
/// see why the fallback fired.
|
||||
pub fn load_chat_template_from(path: &Path) -> Option<String> {
|
||||
let text = match std::fs::read_to_string(path) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"chat_template: tokenizer_config.json absent or unreadable; falling back"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let value: Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"chat_template: tokenizer_config.json failed to parse; falling back"
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
// Some tokenizer_config.json files carry `chat_template` as an
|
||||
// array of `{name, template}` objects (multi-template models —
|
||||
// tool-use variant, default variant). For now we pick the first
|
||||
// entry; future iterations could honour a name hint.
|
||||
match value.get("chat_template") {
|
||||
Some(Value::String(s)) => Some(s.clone()),
|
||||
Some(Value::Array(arr)) => {
|
||||
for entry in arr {
|
||||
if let Some(t) = entry.get("template").and_then(|v| v.as_str()) {
|
||||
return Some(t.to_string());
|
||||
}
|
||||
}
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
"chat_template: array form had no usable template entry; falling back"
|
||||
);
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the chat template into the prompt the model expects.
|
||||
///
|
||||
/// `template` is the raw Jinja string from `tokenizer_config.json`.
|
||||
/// `messages` is the conversation in order. `kwargs` is the
|
||||
/// `chat_template_kwargs` object the client supplied on the
|
||||
/// request (or `Value::Null` when absent). The function expands
|
||||
/// the kwargs into the Jinja context alongside the standard
|
||||
/// `messages` and `add_generation_prompt` variables HF templates
|
||||
/// expect.
|
||||
///
|
||||
/// `tools` is the request's `tools` array (or `Value::Null`).
|
||||
/// Some chat templates iterate it to emit native tool definitions
|
||||
/// (Qwen3-Coder's tool-use template, Mistral's [TOOL_DEFINITIONS]
|
||||
/// frame). We forward whatever the client sent without
|
||||
/// interpretation.
|
||||
pub fn render_chat_template(
|
||||
template: &str,
|
||||
messages: &[ChatMessage],
|
||||
tools: &Value,
|
||||
kwargs: &Value,
|
||||
) -> Result<String> {
|
||||
let mut env = Environment::new();
|
||||
// Compile the template against a fixed name so error messages
|
||||
// surface "chat_template" rather than `<template>`.
|
||||
env.add_template("chat_template", template)
|
||||
.context("compile chat_template")?;
|
||||
let tmpl = env.get_template("chat_template").unwrap();
|
||||
|
||||
// Convert our internal ChatMessage shape into the
|
||||
// `[{role, content}]` shape HF templates iterate. Text content
|
||||
// becomes a string; Parts becomes an array of content blocks.
|
||||
// The HF templates handle both shapes via `content is string`
|
||||
// checks or content-array iteration.
|
||||
let messages_json: Vec<Value> = messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let content_value = match &m.content {
|
||||
MessageContent::Text(s) => Value::String(s.clone()),
|
||||
MessageContent::Parts(parts) => Value::Array(parts.clone()),
|
||||
};
|
||||
let mut obj = serde_json::Map::new();
|
||||
obj.insert("role".into(), Value::String(m.role.clone()));
|
||||
obj.insert("content".into(), content_value);
|
||||
// Forward extras (e.g. tool_calls on assistant turns,
|
||||
// tool_call_id on tool result turns). HF templates that
|
||||
// need them read e.g. `message.tool_calls`.
|
||||
if let Value::Object(extras) = &m.extra {
|
||||
for (k, v) in extras {
|
||||
obj.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
Value::Object(obj)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build the kwargs context. Add base bindings the template
|
||||
// expects (`messages`, `add_generation_prompt`, `tools`) plus
|
||||
// anything the caller passed in `chat_template_kwargs`. Caller
|
||||
// kwargs override the defaults so `add_generation_prompt: false`
|
||||
// from the request actually wins.
|
||||
let mut ctx_map = serde_json::Map::new();
|
||||
ctx_map.insert("messages".into(), Value::Array(messages_json));
|
||||
ctx_map.insert("add_generation_prompt".into(), Value::Bool(true));
|
||||
if !tools.is_null() {
|
||||
ctx_map.insert("tools".into(), tools.clone());
|
||||
}
|
||||
if let Value::Object(kwargs_obj) = kwargs {
|
||||
for (k, v) in kwargs_obj {
|
||||
ctx_map.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
// `Template::render` takes any Serialize value; serde_json's
|
||||
// `Value` implements it natively, so we pass the assembled
|
||||
// context object directly without going through the
|
||||
// `context!` macro (which expects minijinja-native values).
|
||||
tmpl.render(Value::Object(ctx_map))
|
||||
.context("render chat_template")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn user_msg(text: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: MessageContent::Text(text.into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_msg(text: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(text.into()),
|
||||
extra: Value::Object(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimal Qwen3-style template — enough surface to confirm
|
||||
/// our renderer threads role + content correctly without
|
||||
/// loading a real model's tokenizer_config.json.
|
||||
const QWEN3_LIKE: &str = "{%- for message in messages -%}\
|
||||
<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n\
|
||||
{%- endfor -%}\
|
||||
{%- if add_generation_prompt -%}<|im_start|>assistant\n{%- endif -%}";
|
||||
|
||||
#[test]
|
||||
fn renders_basic_conversation() {
|
||||
let prompt = render_chat_template(
|
||||
QWEN3_LIKE,
|
||||
&[user_msg("hello"), assistant_msg("hi"), user_msg("bye")],
|
||||
&Value::Null,
|
||||
&Value::Null,
|
||||
)
|
||||
.unwrap();
|
||||
// Structural assertions — the exact whitespace produced
|
||||
// by a given template is a Jinja-trim concern that varies
|
||||
// per real chat_template. What matters is that every
|
||||
// turn's role + content thread through in order, and that
|
||||
// the generation cue lands at the end.
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>user\nhello<|im_end|>"),
|
||||
"first user turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>assistant\nhi<|im_end|>"),
|
||||
"assistant turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("<|im_start|>user\nbye<|im_end|>"),
|
||||
"second user turn missing: {prompt}"
|
||||
);
|
||||
assert!(
|
||||
prompt.ends_with("<|im_start|>assistant")
|
||||
|| prompt.ends_with("<|im_start|>assistant\n"),
|
||||
"generation cue missing at end: {prompt}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kwargs_are_threaded_into_template_context() {
|
||||
// Replica of Qwen3's enable_thinking branch in
|
||||
// simplified form. When the kwarg is false, the model's
|
||||
// template injects an empty `<think>...</think>` block
|
||||
// before the generation cue — pre-filling the model's
|
||||
// reasoning slot with "no thinking" so the model emits
|
||||
// the answer directly.
|
||||
let template = "{%- if enable_thinking is defined and enable_thinking is false -%}\
|
||||
NO_THINK\
|
||||
{%- else -%}\
|
||||
THINK_OK\
|
||||
{%- endif -%}";
|
||||
let r_disabled = render_chat_template(
|
||||
template,
|
||||
&[],
|
||||
&Value::Null,
|
||||
&json!({ "enable_thinking": false }),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(r_disabled, "NO_THINK");
|
||||
let r_default = render_chat_template(template, &[], &Value::Null, &Value::Null).unwrap();
|
||||
assert_eq!(r_default, "THINK_OK");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_template_field_returns_none() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-missing-field.json");
|
||||
std::fs::write(&tmp, r#"{"some_other_field": 1}"#).unwrap();
|
||||
assert!(load_chat_template_from(&tmp).is_none());
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_template_from_string_field() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-string.json");
|
||||
std::fs::write(
|
||||
&tmp,
|
||||
r#"{"chat_template": "hello {{ messages[0].content }}"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||
assert!(t.contains("messages[0].content"));
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_template_from_array_form() {
|
||||
// Some HF models ship `chat_template` as `[{name, template}, ...]`.
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-array.json");
|
||||
std::fs::write(
|
||||
&tmp,
|
||||
r#"{"chat_template": [{"name": "default", "template": "ARR"}]}"#,
|
||||
)
|
||||
.unwrap();
|
||||
let t = load_chat_template_from(&tmp).expect("template loaded");
|
||||
assert_eq!(t, "ARR");
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_file_returns_none_quietly() {
|
||||
let absent = std::path::PathBuf::from("/definitely/not/a/real/path.json");
|
||||
assert!(load_chat_template_from(&absent).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_returns_none() {
|
||||
let tmp = std::env::temp_dir().join("neuron-test-tokenizer-garbage.json");
|
||||
std::fs::write(&tmp, b"{not valid json").unwrap();
|
||||
assert!(load_chat_template_from(&tmp).is_none());
|
||||
let _ = std::fs::remove_file(tmp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kill_switch_recognises_truthy_falsy_values() {
|
||||
// Test against the actual env var so callers see the
|
||||
// same behaviour as production. Serialise via a
|
||||
// mutex — see path_util.rs for the pattern.
|
||||
use std::sync::Mutex;
|
||||
static LOCK: Mutex<()> = Mutex::new(());
|
||||
let _g = LOCK.lock().unwrap();
|
||||
let prior = std::env::var(KILL_SWITCH_ENV).ok();
|
||||
unsafe {
|
||||
std::env::remove_var(KILL_SWITCH_ENV);
|
||||
}
|
||||
assert!(chat_templates_enabled());
|
||||
for value in ["false", "0", "no", "off", "FALSE", " no "] {
|
||||
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
|
||||
assert!(!chat_templates_enabled(), "value {value:?} should disable");
|
||||
}
|
||||
for value in ["true", "1", "yes", ""] {
|
||||
unsafe { std::env::set_var(KILL_SWITCH_ENV, value) };
|
||||
assert!(chat_templates_enabled(), "value {value:?} should enable");
|
||||
}
|
||||
unsafe {
|
||||
match prior {
|
||||
Some(p) => std::env::set_var(KILL_SWITCH_ENV, p),
|
||||
None => std::env::remove_var(KILL_SWITCH_ENV),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_extras_thread_through_for_tool_calls() {
|
||||
// HF templates read assistant.tool_calls and tool
|
||||
// turns' tool_call_id. Confirm our extras flatten into
|
||||
// the message object the template iterates.
|
||||
let mut extras = serde_json::Map::new();
|
||||
extras.insert(
|
||||
"tool_calls".into(),
|
||||
json!([{"id": "t1", "function": {"name": "x", "arguments": "{}"}}]),
|
||||
);
|
||||
let msg = ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(String::new()),
|
||||
extra: Value::Object(extras),
|
||||
};
|
||||
let template = "{{ messages[0].tool_calls[0].id }}";
|
||||
let rendered = render_chat_template(template, &[msg], &Value::Null, &Value::Null).unwrap();
|
||||
assert_eq!(rendered, "t1");
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
pub mod arch;
|
||||
pub mod candle;
|
||||
pub mod chat_template;
|
||||
pub mod device_worker;
|
||||
pub mod tp;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user