Stage 3 of the candle-native pivot. neuron now serves POST
/v1/chat/completions backed by candle's quantized_qwen3 forward pass,
run on a per-model serialised generation loop, and returns the standard
OpenAI ChatCompletionResponse envelope.

Pipeline per request:
- Look up the LoadedModel by request.model (404 if absent).
- Apply the Qwen3 chat template across all messages.
- Tokenize, then spawn_blocking onto tokio's blocking pool to acquire
  the per-model arch lock and run prefill plus greedy/temperature/top-p
  sampling via LogitsProcessor.
- Stop on an <|im_end|> or <|endoftext|> EOS token, or at max_tokens
  (finish_reason "stop" vs "length" respectively).
- Decode with skip_special_tokens=true and build the OpenAI response
  with prompt/completion/total usage counts.

Supporting changes:
- HarnessRegistry now stores Arc<dyn Harness> and caches a typed
  Arc<CandleHarness> so inference routes bypass dyn Trait dispatch.
- LoadedModel.arch becomes Arc<Mutex<ModelArch>> so the lock guard can
  be moved into spawn_blocking.
- NeuronState gains an Option<Arc<CandleHarness>> field for the new
  inference route.
- A typed InferenceError lets the handler map ModelNotLoaded → 404 and
  other failures → 500 without string-matching anyhow messages.
- stream=true returns 501 until Stage 4 wires up SSE.
- Two leftover mistral.rs string references in proxy.rs and cortex-cli
  (missed during the Stage 1 sweep) are corrected here.

Three new default-feature tests cover the no-candle 503,
model-not-loaded 404, and stream=true 501 paths. The cuda-integration
test from Stage 2 still covers real load/unload; a streaming-feature-gated
test exercising actual generation will arrive with Stage 4.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
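The chat-template step can be pictured as ChatML-style rendering, which is the layout Qwen models use. This sketch simplifies under stated assumptions. `Message` and `render_chatml` are hypothetical names. The real Qwen3 template, distributed with the model's tokenizer configuration, also handles tool calls and thinking-mode blocks, and those are left out here.

```rust
/// A chat message as it might arrive in an OpenAI-style request (hypothetical type).
struct Message {
    role: String,
    content: String,
}

/// Render messages in a ChatML layout: `<|im_start|>{role}\n{content}<|im_end|>\n`
/// per message, followed by an open assistant turn so the model continues as
/// the assistant. Simplification: no tool calls and no `<think>` handling.
fn render_chatml(messages: &[Message]) -> String {
    let mut prompt = String::new();
    for m in messages {
        prompt.push_str("<|im_start|>");
        prompt.push_str(&m.role);
        prompt.push('\n');
        prompt.push_str(&m.content);
        prompt.push_str("<|im_end|>\n");
    }
    // Generation prompt: leave the assistant turn open for the model to fill.
    prompt.push_str("<|im_start|>assistant\n");
    prompt
}

fn main() {
    let msgs = vec![
        Message { role: "system".into(), content: "You are terse.".into() },
        Message { role: "user".into(), content: "Hi".into() },
    ];
    print!("{}", render_chatml(&msgs));
}
```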
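The stop conditions and usage accounting in the pipeline can be sketched with a plain decode loop. This is a sketch under assumptions. `next_token` is a hypothetical stand-in for one forward pass plus LogitsProcessor sampling, and the token ids in the example are made up. Whether an EOS token counts toward completion_tokens is also an assumption; here it does not.

```rust
/// Why generation stopped, mirroring OpenAI's `finish_reason` strings.
#[derive(Debug, PartialEq)]
enum FinishReason {
    Stop,
    Length,
}

impl FinishReason {
    fn as_str(&self) -> &'static str {
        match self {
            FinishReason::Stop => "stop",
            FinishReason::Length => "length",
        }
    }
}

/// Token counts as reported in the response's `usage` object.
#[derive(Debug, PartialEq)]
struct Usage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

/// Decode until an EOS id appears ("stop") or `max_tokens` tokens have been
/// produced ("length"). `next_token` is hypothetical: in neuron it would be a
/// model forward pass followed by sampling. EOS is not kept in the output.
fn generate(
    prompt_len: usize,
    eos_ids: &[u32],
    max_tokens: usize,
    mut next_token: impl FnMut() -> u32,
) -> (Vec<u32>, FinishReason, Usage) {
    let mut out = Vec::new();
    let mut reason = FinishReason::Length;
    while out.len() < max_tokens {
        let tok = next_token();
        if eos_ids.contains(&tok) {
            reason = FinishReason::Stop;
            break;
        }
        out.push(tok);
    }
    let usage = Usage {
        prompt_tokens: prompt_len,
        completion_tokens: out.len(),
        total_tokens: prompt_len + out.len(),
    };
    (out, reason, usage)
}

fn main() {
    // Fake model: emits 10, 11, 12, then the made-up EOS id 99.
    let mut script = vec![10u32, 11, 12, 99].into_iter();
    let (toks, reason, usage) = generate(5, &[99], 16, || script.next().unwrap_or(99));
    println!("{toks:?} {} {usage:?}", reason.as_str());
}
```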
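Caching a typed Arc<CandleHarness> next to the Arc<dyn Harness> list could look roughly like the sketch below. The trait, its `name` method, and `CandleHarness::generate` are hypothetical. The point is that both handles share one allocation, and the typed handle lets the inference route call inherent methods with static dispatch.

```rust
use std::sync::Arc;

/// Hypothetical harness trait used for generic registry operations.
trait Harness: Send + Sync {
    fn name(&self) -> &str;
}

/// Hypothetical concrete harness with an inherent, non-trait method.
struct CandleHarness;

impl CandleHarness {
    fn generate(&self, prompt: &str) -> String {
        format!("echo: {prompt}")
    }
}

impl Harness for CandleHarness {
    fn name(&self) -> &str {
        "candle"
    }
}

#[derive(Default)]
struct HarnessRegistry {
    /// Every harness, type-erased for generic operations.
    all: Vec<Arc<dyn Harness>>,
    /// The same candle harness, kept with its concrete type.
    candle: Option<Arc<CandleHarness>>,
}

impl HarnessRegistry {
    fn register_candle(&mut self, h: Arc<CandleHarness>) {
        // Arc<CandleHarness> coerces to Arc<dyn Harness>; both point at one allocation.
        self.all.push(h.clone());
        self.candle = Some(h);
    }

    fn candle(&self) -> Option<Arc<CandleHarness>> {
        self.candle.clone()
    }
}

fn main() {
    let mut reg = HarnessRegistry::default();
    reg.register_candle(Arc::new(CandleHarness));
    let names: Vec<&str> = reg.all.iter().map(|h| h.name()).collect();
    println!("{names:?}");
    if let Some(c) = reg.candle() {
        println!("{}", c.generate("hi")); // statically dispatched call
    }
}
```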
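The per-model lock can be illustrated with std threads standing in for tokio::task::spawn_blocking. That swap is made so the sketch needs no async runtime. The Arc is cloned into the worker and the lock is taken inside it, which serialises concurrent generations on the same model.

The commit's wording suggests the guard itself is moved instead. With tokio, that is usually done via tokio::sync::Mutex::lock_owned, which takes an Arc<Mutex<_>> and returns an owned guard. Treat that as an assumption about neuron's code. ModelArch here is a hypothetical counter, not candle's model type.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for model state: counts forward passes.
struct ModelArch {
    forward_calls: u64,
}

/// Shared per-model handle; the Mutex serialises generation on this model.
struct LoadedModel {
    arch: Arc<Mutex<ModelArch>>,
}

/// Run one "generation" on a worker thread. Only the Arc (Send + 'static) is
/// moved into the closure; the lock is acquired inside the worker, so
/// concurrent requests for the same model run one at a time.
fn run_generation(model: &LoadedModel, steps: u64) -> thread::JoinHandle<u64> {
    let arch = Arc::clone(&model.arch);
    thread::spawn(move || {
        let mut guard = arch.lock().expect("model mutex poisoned");
        for _ in 0..steps {
            guard.forward_calls += 1; // a forward pass plus sampling would go here
        }
        guard.forward_calls
    })
}

fn main() {
    let model = LoadedModel { arch: Arc::new(Mutex::new(ModelArch { forward_calls: 0 })) };
    let handles: Vec<_> = (0..4).map(|_| run_generation(&model, 10)).collect();
    for h in handles {
        h.join().unwrap();
    }
    println!("total forward calls: {}", model.arch.lock().unwrap().forward_calls);
}
```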
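A typed error turns the status mapping into an exhaustive match rather than a search through anyhow message strings. In this sketch only ModelNotLoaded comes from the commit. The following are all assumptions:
- the Backend variant;
- the preflight helper and the order of its checks;
- the use of bare u16 codes instead of a web framework's status type.

```rust
use std::fmt;

/// Hypothetical shape of the typed error; only `ModelNotLoaded` is named in the commit.
#[derive(Debug)]
enum InferenceError {
    ModelNotLoaded(String),
    Backend(String),
}

impl fmt::Display for InferenceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            InferenceError::ModelNotLoaded(m) => write!(f, "model '{m}' is not loaded"),
            InferenceError::Backend(msg) => write!(f, "inference failed: {msg}"),
        }
    }
}

impl std::error::Error for InferenceError {}

/// Map an error to an HTTP status by matching on the variant, not on message text.
fn status_for(err: &InferenceError) -> u16 {
    match err {
        InferenceError::ModelNotLoaded(_) => 404,
        InferenceError::Backend(_) => 500,
    }
}

/// Hypothetical pre-flight checks before inference: no candle harness -> 503,
/// streaming requested -> 501. The order of the checks is an assumption.
fn preflight(candle_available: bool, stream: bool) -> Result<(), u16> {
    if !candle_available {
        return Err(503);
    }
    if stream {
        return Err(501);
    }
    Ok(())
}

fn main() {
    let err = InferenceError::ModelNotLoaded("qwen3-example".into());
    println!("{} -> {}", err, status_for(&err));
    println!("{:?}", preflight(true, true));
}
```

Because the match is exhaustive, adding a new variant forces a decision about its status code at compile time.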
112 lines
3.1 KiB
Rust
use anyhow::Result;
use clap::{Parser, Subcommand};
use cortex_core::config::GatewayConfig;
use tracing_subscriber::EnvFilter;

#[derive(Parser)]
#[command(name = "cortex")]
#[command(about = "Unified inference gateway for multi-node GPU clusters")]
#[command(version)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    /// Start the gateway server.
    Serve {
        /// Path to the gateway config file.
        #[arg(short, long, default_value = "cortex.toml")]
        config: String,
    },
    /// Print the fleet status (models, nodes, health).
    Status {
        /// Gateway API endpoint to query.
        #[arg(short, long, default_value = "http://localhost:31313")]
        endpoint: String,
    },
}

#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing with env filter (e.g. RUST_LOG=cortex_gateway=debug).
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| EnvFilter::new("info,cortex_gateway=debug")),
        )
        .init();

    let cli = Cli::parse();

    match cli.command {
        Commands::Serve { config } => {
            let cfg = GatewayConfig::load(&config)
                .map_err(|e| anyhow::anyhow!("failed to load config from '{config}': {e}"))?;

            tracing::info!(
                neurons = cfg.neurons.len(),
                listen = %cfg.gateway.listen,
                "starting cortex"
            );

            // Install Prometheus metrics exporter on a separate port.
            cortex_gateway::metrics::install(&cfg.gateway.metrics_listen)?;

            cortex_gateway::run(cfg).await?;
        }
        Commands::Status { endpoint } => {
            print_status(&endpoint).await?;
        }
    }

    Ok(())
}

async fn print_status(endpoint: &str) -> Result<()> {
    let client = reqwest::Client::new();

    // Fetch health.
    let health: serde_json::Value = client
        .get(format!("{endpoint}/health"))
        .send()
        .await?
        .json()
        .await?;

    println!("Fleet health: {}", serde_json::to_string_pretty(&health)?);

    // Fetch models.
    let models: serde_json::Value = client
        .get(format!("{endpoint}/v1/models"))
        .send()
        .await?
        .json()
        .await?;

    println!("\nModels:");
    if let Some(data) = models.get("data").and_then(|d| d.as_array()) {
        for model in data {
            let id = model.get("id").and_then(|v| v.as_str()).unwrap_or("?");
            let locations = model
                .get("locations")
                .and_then(|v| v.as_array())
                .map(|arr| {
                    arr.iter()
                        .filter_map(|l| {
                            let node = l.get("node")?.as_str()?;
                            let status = l.get("status")?.as_str()?;
                            Some(format!("{node}({status})"))
                        })
                        .collect::<Vec<_>>()
                        .join(", ")
                })
                .unwrap_or_default();
            println!(" {id:40} {locations}");
        }
    }

    Ok(())
}
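print_status builds request URLs with format!("{endpoint}/health"). An --endpoint value with a trailing slash therefore yields a double slash (for example http://localhost:31313//health), which the gateway's router may not match. Below is a minimal sketch of a normalising helper, assuming one is wanted. `api_url` is hypothetical and not part of the file above.

```rust
/// Hypothetical helper: join a base endpoint and an API path with exactly one
/// slash between them, whether or not either side carries one.
fn api_url(endpoint: &str, path: &str) -> String {
    format!(
        "{}/{}",
        endpoint.trim_end_matches('/'),
        path.trim_start_matches('/')
    )
}

fn main() {
    for ep in ["http://localhost:31313", "http://localhost:31313/"] {
        // Both iterations print the same normalised URL.
        println!("{}", api_url(ep, "/v1/models"));
    }
}
```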