diff --git a/crates/neuron/src/harness/candle.rs b/crates/neuron/src/harness/candle.rs index 3dc8522..ff8dc3e 100644 --- a/crates/neuron/src/harness/candle.rs +++ b/crates/neuron/src/harness/candle.rs @@ -11,9 +11,11 @@ use anyhow::{Context, Result}; use async_trait::async_trait; use candle_core::quantized::gguf_file; -use candle_core::{Device, Tensor}; +use candle_core::{DType, Device, Tensor}; +use candle_nn::VarBuilder; use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_transformers::models::quantized_qwen3::ModelWeights as QuantizedQwen3Weights; +use candle_transformers::models::qwen3 as qwen3_dense; use cortex_core::harness::{Harness, HarnessHealth, ModelInfo, ModelSpec}; use cortex_core::openai::{ ChatCompletionChoice, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, @@ -46,11 +48,21 @@ pub struct LoadedModel { pub devices: Vec, } -/// Architecture-specific weights. Stage 3 still supports only Qwen3 -/// quantized; Stage 8 broadens this to additional families and -/// non-quantized variants. +/// Architecture-specific weights. +/// +/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only; +/// TP sharding pre-quantized super-blocks is intractable. Stays the +/// default for small models loaded via `Qwen/Qwen3-*-GGUF` and +/// `unsloth/Qwen3-*-GGUF` repos. +/// - `Qwen3Dense` — bf16 safetensors source. The path that supports +/// TP (Stage 7b-ii+) because slicing dense weights by row/column +/// under safetensors is mechanical. Used when `ModelSpec.quant` is +/// None; intended target for Qwen3.6-27B etc. +/// +/// Stage 8 broadens this to additional families. pub enum ModelArch { Qwen3Quantized(QuantizedQwen3Weights), + Qwen3Dense(qwen3_dense::ModelForCausalLM), } /// Repetition penalty applied to recently-generated tokens before @@ -108,14 +120,172 @@ impl CandleHarness { Ok(Device::Cpu) } - /// Resolve a model spec to local GGUF and tokenizer file paths via - /// hf-hub. Downloads on first use; subsequent calls are cached. - async fn resolve_files(&self, spec: &ModelSpec) -> Result<(PathBuf, PathBuf)> { + /// Build an hf-hub API client pre-configured with the harness's + /// `hf_cache` (when one is set). + fn hf_api(&self) -> Result { let mut builder = hf_hub::api::tokio::ApiBuilder::new(); if let Some(cache) = &self.hf_cache { builder = builder.with_cache_dir(cache.clone()); } - let api = builder.build().context("build hf-hub API")?; + builder.build().context("build hf-hub API") + } + + /// Resolve a dense (bf16/fp16 safetensors) model to its local file + /// paths. + /// + /// Handles both sharded repos (`model.safetensors.index.json` plus + /// several `model-*.safetensors`) and the single-file layout + /// (`model.safetensors`). Returns the safetensors paths in + /// arbitrary order — `VarBuilder` unifies them into one tensor view. + async fn resolve_dense_files( + &self, + spec: &ModelSpec, + ) -> Result<(PathBuf, PathBuf, Vec)> { + let api = self.hf_api()?; + let repo = api.model(spec.model_id.clone()); + + let config_path = repo + .get("config.json") + .await + .with_context(|| format!("fetch config.json from {}", spec.model_id))?; + let tokenizer_path = repo + .get("tokenizer.json") + .await + .with_context(|| format!("fetch tokenizer.json from {}", spec.model_id))?; + + // Prefer the sharded layout (most HF dense models > 5B ship it). + let safetensors_paths = match repo.get("model.safetensors.index.json").await { + Ok(index_path) => { + let index_text = std::fs::read_to_string(&index_path) + .context("read model.safetensors.index.json")?; + let index: serde_json::Value = serde_json::from_str(&index_text) + .context("parse model.safetensors.index.json")?; + let weight_map = index + .get("weight_map") + .and_then(|v| v.as_object()) + .ok_or_else(|| { + anyhow::anyhow!("safetensors index missing weight_map object") + })?; + let unique: std::collections::BTreeSet = weight_map + .values() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + let mut paths = Vec::with_capacity(unique.len()); + for fname in unique { + let p = repo + .get(&fname) + .await + .with_context(|| format!("fetch sharded safetensors {fname}"))?; + paths.push(p); + } + paths + } + Err(_) => { + // Single-file fallback. + let p = repo + .get("model.safetensors") + .await + .context("fetch model.safetensors (single-file layout)")?; + vec![p] + } + }; + Ok((config_path, tokenizer_path, safetensors_paths)) + } + + /// Resolve + load a GGUF (pre-quantized) Qwen3. Returns the + /// tokenizer.json path so the caller can construct the Tokenizer + /// uniformly across source formats. + async fn load_arch_gguf( + &self, + spec: &ModelSpec, + device: &Device, + ) -> Result<(PathBuf, ModelArch)> { + let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?; + let device_for_load = device.clone(); + let gguf_path_for_load = gguf_path.clone(); + let model_id_for_log = spec.model_id.clone(); + let arch = tokio::task::spawn_blocking(move || -> Result { + tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF"); + let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?; + let content = gguf_file::Content::read(&mut file) + .map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?; + + let architecture = content + .metadata + .get("general.architecture") + .and_then(|v| v.to_string().ok().cloned()) + .unwrap_or_default(); + tracing::info!(architecture = %architecture, "GGUF architecture"); + + match architecture.as_str() { + "qwen3" => { + let weights = + QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load) + .map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?; + Ok(ModelArch::Qwen3Quantized(weights)) + } + other => anyhow::bail!( + "unsupported GGUF architecture '{other}'; quantized path only supports qwen3" + ), + } + }) + .await + .context("blocking GGUF load task panicked")??; + Ok((tokenizer_path, arch)) + } + + /// Resolve + load a dense Qwen3 from safetensors. Uses + /// `candle-transformers::models::qwen3::ModelForCausalLM` and + /// builds a VarBuilder over the mmap'd safetensors files. dtype + /// is bf16 by default to match the HF distribution dtype for + /// recent Qwen3 family models; fall back to f16 if the device + /// doesn't support bf16. + async fn load_arch_dense( + &self, + spec: &ModelSpec, + device: &Device, + ) -> Result<(PathBuf, ModelArch)> { + let (config_path, tokenizer_path, safetensors_paths) = + self.resolve_dense_files(spec).await?; + let device_for_load = device.clone(); + let model_id_for_log = spec.model_id.clone(); + + let arch = tokio::task::spawn_blocking(move || -> Result { + tracing::info!( + model = %model_id_for_log, + shards = safetensors_paths.len(), + "loading dense Qwen3 from safetensors" + ); + let cfg_text = std::fs::read_to_string(&config_path).context("read config.json")?; + let cfg: qwen3_dense::Config = + serde_json::from_str(&cfg_text).context("parse Qwen3 config.json")?; + + // bf16 is the canonical Qwen3 distribution dtype. CUDA + // devices on Ada+ support it; Ampere also supports bf16 + // natively. CPU candle handles bf16 via emulation. + let dtype = DType::BF16; + // SAFETY: VarBuilder::from_mmaped_safetensors mmaps the files; + // mutation of the underlying files by another process while + // we hold the mapping is UB. We trust that nothing else on + // the host modifies the HF cache files during a model's + // lifetime (hf-hub itself is immutable-by-design). + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&safetensors_paths, dtype, &device_for_load) + .context("build VarBuilder over safetensors")? + }; + let model = qwen3_dense::ModelForCausalLM::new(&cfg, vb) + .map_err(|e| anyhow::anyhow!("build Qwen3 dense model: {e}"))?; + Ok(ModelArch::Qwen3Dense(model)) + }) + .await + .context("blocking dense load task panicked")??; + Ok((tokenizer_path, arch)) + } + + /// Resolve a model spec to local GGUF and tokenizer file paths via + /// hf-hub. Downloads on first use; subsequent calls are cached. + async fn resolve_files(&self, spec: &ModelSpec) -> Result<(PathBuf, PathBuf)> { + let api = self.hf_api()?; let repo = api.model(spec.model_id.clone()); let info = repo @@ -406,13 +576,15 @@ impl Harness for CandleHarness { // Stage 7a-i scaffolds tensor-parallel worker subprocesses but // does not yet route inference through them. Refuse TP loads - // for now with a clear marker so the request surface is honest. + // for now with a clear marker so the request surface is honest; + // Stage 7b-iv replaces this bail with the TP dispatch. let tp_size = spec.tensor_parallel.unwrap_or(1); if tp_size > 1 { anyhow::bail!( "tensor_parallel={tp_size} requested for '{}': TP worker \ - lifecycle is in place (Stage 7a-i) but TP-aware Qwen3 \ - inference lands in Stage 7b; single-GPU loads only for now", + lifecycle + NCCL handshake are in place (Stage 7a) but \ + TP-aware Qwen3 inference orchestration lands in Stage \ + 7b-iv; single-GPU loads only for now", spec.model_id ); } @@ -420,44 +592,19 @@ impl Harness for CandleHarness { let devices = spec.devices.clone().unwrap_or_else(|| vec![0]); let device = Self::pick_device(&devices)?; - let (gguf_path, tokenizer_path) = self.resolve_files(spec).await?; + // Dispatch by source format: GGUF (pre-quantized, single-GPU + // only path) vs safetensors dense (bf16/fp16; the path that + // grows TP support). `spec.quant` is the signal — Some means + // the operator picked a quantized GGUF; None means dense. + let (tokenizer_path, arch) = if spec.quant.is_some() { + self.load_arch_gguf(spec, &device).await? + } else { + self.load_arch_dense(spec, &device).await? + }; let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?; - // File I/O + GGUF parsing + tensor materialisation are CPU-bound, - // so run them on a blocking task to avoid stalling the runtime. - let device_for_load = device.clone(); - let gguf_path_for_load = gguf_path.clone(); - let model_id_for_log = spec.model_id.clone(); - let arch = tokio::task::spawn_blocking(move || -> Result { - tracing::info!(model = %model_id_for_log, path = ?gguf_path_for_load, "loading GGUF"); - let mut file = std::fs::File::open(&gguf_path_for_load).context("open GGUF file")?; - let content = gguf_file::Content::read(&mut file) - .map_err(|e| anyhow::anyhow!("parse GGUF: {e}"))?; - - let architecture = content - .metadata - .get("general.architecture") - .and_then(|v| v.to_string().ok().cloned()) - .unwrap_or_default(); - tracing::info!(architecture = %architecture, "GGUF architecture"); - - match architecture.as_str() { - "qwen3" => { - let weights = - QuantizedQwen3Weights::from_gguf(content, &mut file, &device_for_load) - .map_err(|e| anyhow::anyhow!("from_gguf qwen3: {e}"))?; - Ok(ModelArch::Qwen3Quantized(weights)) - } - other => anyhow::bail!( - "unsupported GGUF architecture '{other}'; Stage 3 only supports qwen3" - ), - } - }) - .await - .context("blocking load task panicked")??; - let loaded = Arc::new(LoadedModel { model_id: spec.model_id.clone(), arch: Arc::new(Mutex::new(arch)), @@ -564,6 +711,13 @@ fn run_inference( let logits = logits.squeeze(0)?; sample_with_penalty(&logits, &generated, &mut logits_processor)? } + ModelArch::Qwen3Dense(model) => { + model.clear_kv_cache(); + let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + sample_with_penalty(&logits, &generated, &mut logits_processor)? + } }; if Some(next_token) == eos_id { @@ -579,6 +733,12 @@ fn run_inference( let logits = logits.squeeze(0)?; sample_with_penalty(&logits, &generated, &mut logits_processor)? } + ModelArch::Qwen3Dense(model) => { + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, prompt_tokens.len() + index)?; + let logits = logits.squeeze(0)?; + sample_with_penalty(&logits, &generated, &mut logits_processor)? + } }; if Some(next_token) == eos_id { return Ok((generated, "stop".into())); @@ -635,6 +795,13 @@ fn run_inference_streaming( let logits = logits.squeeze(0)?; sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? } + ModelArch::Qwen3Dense(model) => { + model.clear_kv_cache(); + let input = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? + } }; let emit_token = |all_tokens: &[u32], decoded_prefix: &mut String| -> Result { @@ -683,6 +850,12 @@ fn run_inference_streaming( let logits = logits.squeeze(0)?; sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? } + ModelArch::Qwen3Dense(model) => { + let input = Tensor::new(&[next_token], device)?.unsqueeze(0)?; + let logits = model.forward(&input, prompt_tokens.len() + index)?; + let logits = logits.squeeze(0)?; + sample_with_penalty(&logits, &all_tokens, &mut logits_processor)? + } }; if Some(next_token) == eos_id { finish_reason = "stop".into(); diff --git a/crates/neuron/src/harness/tp/mod.rs b/crates/neuron/src/harness/tp/mod.rs index 46e35d6..c6c33e4 100644 --- a/crates/neuron/src/harness/tp/mod.rs +++ b/crates/neuron/src/harness/tp/mod.rs @@ -213,8 +213,7 @@ impl WorkerPool { // Swap out the leader's NcclState into a fresh empty one so we // can move it into spawn_blocking; restore after the task // returns. (NcclState isn't Clone — it owns a real NCCL Comm.) - let mut leader_state = - std::mem::take(&mut self.leader_nccl); + let mut leader_state = std::mem::take(&mut self.leader_nccl); let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || { let resp = leader_state.init(leader_cfg, &comm_id_for_leader); (leader_state, resp) @@ -269,8 +268,7 @@ impl WorkerPool { // 2. Leader's own all_reduce, in spawn_blocking. NCCL operations // block until every rank participates. - let mut leader_state = - std::mem::take(&mut self.leader_nccl); + let mut leader_state = std::mem::take(&mut self.leader_nccl); let (returned_state, leader_resp) = tokio::task::spawn_blocking(move || { let resp = leader_state.sanity_check(); (leader_state, resp) diff --git a/script/validate-neuron.sh b/script/validate-neuron.sh index 47a892b..b0f974f 100755 --- a/script/validate-neuron.sh +++ b/script/validate-neuron.sh @@ -29,8 +29,8 @@ BASE="http://${HOST}:${PORT}" # Reasoning probe — concrete, low-temperature answer that small models # can still get right. "Paris" is a strong signal of basic competence # beyond gibberish. -PROBE_PROMPT='What is the capital of France? Respond with the city name only, no punctuation.' -EXPECT_SUBSTR='Paris' +PROBE_PROMPT='What is the capital of Georgia (Caucasus)? Respond with the city name only, no punctuation.' +EXPECT_SUBSTR='Tbilisi' # Qwen3 prepends ... reasoning before the answer when the # chat template enables thinking mode, which eats most of a small token # budget. 256 leaves enough room for thinking + final answer. @@ -67,18 +67,22 @@ is_loaded() { } trigger_load() { - say "POST /models/load ${MODEL_ID} (quant=${QUANT}, device=[0])" + say "POST /models/load ${MODEL_ID} (quant=${QUANT:-}, device=[0])" say " (synchronous; may take a minute on first run while HF downloads)" + # Build the payload via jq so the optional `quant` field is + # omitted entirely when empty — that's the signal to the harness + # to take the dense safetensors load path rather than GGUF. local payload - payload=$(cat <