diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index ac12e95..ea860ac 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -128,6 +128,62 @@ const REPEAT_PENALTY: f32 = 1.1; /// penalty. Matches the candle quantized-qwen3 example default. const REPEAT_LAST_N: usize = 64; +/// Architectures the dense safetensors path can construct. Keep +/// alphabetical; one entry per supported `config.json#/model_type` +/// value. New entries land alongside a new `ModelArch` variant + a new +/// dispatch branch in `load_arch_dense` / `run_inference` / +/// `run_inference_streaming` (plus, for TP, a parallel pattern in +/// `tp_qwen3.rs`). +const DENSE_SUPPORTED_MODEL_TYPES: &[&str] = &["qwen3"]; + +/// Pre-flight check the operator's `config.json` against the set of +/// architectures the dense path actually knows how to build. Surfaces +/// architecture mismatches as a single clean error before the serde +/// deserializer trips on missing fields — the latter happens because +/// every architecture has different hyperparameter names, so when the +/// JSON is e.g. Qwen3.6 wrapped under `text_config: {...}`, candle's +/// `qwen3::Config` finds none of its expected top-level fields and +/// fails with a cryptic `missing field 'vocab_size' at line N col 1`. +/// +/// The result message names the model_type we saw, the supported set, +/// and points at the files an operator (or future contributor) needs +/// to touch to grow the supported set. +fn check_dense_config_supported(config_json: &str, model_id: &str) -> Result<()> { + let v: serde_json::Value = serde_json::from_str(config_json) + .with_context(|| format!("parse config.json for '{model_id}' as JSON"))?; + let model_type = v.get("model_type").and_then(|x| x.as_str()).unwrap_or(""); + if model_type.is_empty() { + anyhow::bail!( + "config.json for '{model_id}' is missing `model_type`; the dense \ + path needs it to gate architecture support (supported: {:?})", + DENSE_SUPPORTED_MODEL_TYPES + ); + } + if DENSE_SUPPORTED_MODEL_TYPES.contains(&model_type) { + return Ok(()); + } + // Bonus context: the model usually also lists architectures, which + // is what `transformers` keys on. Including it makes the error + // self-contained. + let architectures = v + .get("architectures") + .and_then(|x| x.as_array()) + .map(|a| { + a.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect::>() + }) + .unwrap_or_default(); + anyhow::bail!( + "unsupported model_type '{model_type}' for '{model_id}' \ + (architectures={architectures:?}); the dense path supports {:?}. \ + Add a `ModelArch` variant + load/forward branches in \ + crates/neuron/src/harness/candle.rs (and the TP analogue in \ + tp_qwen3.rs) to extend coverage.", + DENSE_SUPPORTED_MODEL_TYPES + ); +} + /// Resolve the effective HuggingFace cache directory for the candle /// harness. Precedence (first hit wins): /// @@ -346,6 +402,7 @@ impl CandleHarness { "loading dense Qwen3 from safetensors" ); let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?; + check_dense_config_supported(&cfg_text, &model_id_for_log)?; let cfg: qwen3_dense::Config = serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?; @@ -820,6 +877,11 @@ impl CandleHarness { let (config_path, tokenizer_path, safetensors_paths) = self.resolve_dense_files(spec).await?; let config_json = std::fs::read_to_string(&config_path).context("read config.json")?; + // Reject unsupported architectures *before* spawning the worker + // pool and fanning out NCCL — otherwise we'd burn the pool + // lifecycle on a load that's guaranteed to fail at deserialise + // time inside every rank. + check_dense_config_supported(&config_json, &spec.model_id)?; // 2. Spawn the worker pool. Rank 0 stays in-process; ranks // 1..tp_size are subprocesses, one per device after the @@ -1518,3 +1580,64 @@ fn unix_subsec_nanos() -> u64 { .map(|d| d.as_nanos() as u64) .unwrap_or(0) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn check_dense_config_accepts_qwen3() { + let cfg = r#"{ + "model_type": "qwen3", + "vocab_size": 151936, + "architectures": ["Qwen3ForCausalLM"] + }"#; + check_dense_config_supported(cfg, "Qwen/Qwen3-1.7B").expect("qwen3 should pass"); + } + + #[test] + fn check_dense_config_rejects_qwen3_5_with_clear_message() { + let cfg = r#"{ + "model_type": "qwen3_5", + "architectures": ["Qwen3_5ForConditionalGeneration"], + "image_token_id": 248056, + "text_config": {"hidden_size": 5120} + }"#; + let err = check_dense_config_supported(cfg, "Qwen/Qwen3.6-27B") + .expect_err("qwen3_5 should be rejected"); + let msg = format!("{err}"); + assert!( + msg.contains("unsupported model_type 'qwen3_5'"), + "message should name the rejected type: {msg}" + ); + assert!( + msg.contains("Qwen/Qwen3.6-27B"), + "message should echo the model id: {msg}" + ); + assert!( + msg.contains("qwen3"), + "message should list the supported set: {msg}" + ); + } + + #[test] + fn check_dense_config_rejects_missing_model_type() { + let cfg = r#"{ "vocab_size": 1234 }"#; + let err = check_dense_config_supported(cfg, "anon/no-type") + .expect_err("missing model_type should be rejected"); + assert!( + format!("{err}").contains("missing `model_type`"), + "message should call out the missing field" + ); + } + + #[test] + fn check_dense_config_rejects_invalid_json() { + let err = check_dense_config_supported("not json", "anon/bad-json") + .expect_err("malformed JSON should be rejected"); + assert!( + format!("{err:#}").contains("config.json"), + "message should mention config.json" + ); + } +}