feat(tp): Stage 7b-iv — RPC + orchestration for TP load/inference
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
All checks were successful
build-prerelease / Resolve version stamps (push) Successful in 38s
CI / Format (push) Successful in 40s
CI / Clippy (push) Successful in 2m20s
build-prerelease / Build cortex binary (push) Successful in 4m25s
build-prerelease / Package cortex RPM (push) Successful in 1m22s
CI / Test (push) Successful in 4m34s
CI / Build cortex SRPM (push) Has been skipped
CI / Publish cortex to COPR (push) Has been skipped
CI / Build neuron SRPM (push) Has been skipped
CI / Publish neuron to COPR (push) Has been skipped
CI / Bump version in source (push) Has been skipped
build-prerelease / Build neuron-blackwell (push) Successful in 3m57s
build-prerelease / Build neuron-ampere (push) Successful in 4m51s
build-prerelease / Build neuron-ada (push) Successful in 5m12s
build-prerelease / Package helexa-neuron-ampere RPM (push) Successful in 2m49s
build-prerelease / Package helexa-neuron-ada RPM (push) Successful in 2m51s
build-prerelease / Package helexa-neuron-blackwell RPM (push) Successful in 3m43s
build-prerelease / Publish to rpm.lair.cafe (unstable) (push) Successful in 1m0s
Wires the in-flight TP machinery (Stage 7a workers, 7b-iii sharded
Qwen3) end to end so a non-streaming chat completion can run across
multiple GPUs via NCCL.
RPC additions (tp/rpc.rs):
- LoadDenseShard{model_id, config_json, safetensors_paths}
- GenerateStep{model_id, tokens, offset}
- ClearKvCache{model_id}
- UnloadModel{model_id}
- LoadDenseShardOk / GenerateStepOk / KvCacheCleared / Unloaded
Worker side (tp/worker.rs):
- WorkerState gains a `models: HashMap<String, TpQwen3ForCausalLM>`
keyed by model_id. LoadDenseShard mmaps safetensors via
ShardedVarBuilder (only this rank's slice materialises), builds the
TP model with the rank's NCCL Comm cloned from NcclState.
- GenerateStep runs the rank-local forward; the resulting logits are
dropped (only the leader's are used for sampling). The forward's
value here is the NCCL collectives inside the row-parallel layers
letting the leader's rank-0 forward make progress.
Pool side (tp/mod.rs):
- WorkerPool::load_dense_shard fans LoadDenseShard out to every worker,
builds rank 0's shard on the leader via spawn_blocking with a fresh
SendComm wrapper at the move boundary (Comm is !Send at the type
level), collects per-rank LoadDenseShardOk. Returns the leader's
Arc<Mutex<TpQwen3ForCausalLM>>.
- WorkerPool::generate_step fans GenerateStep out, runs the leader's
rank-0 forward in spawn_blocking (the AllReduce CustomOps inside
row-parallel layers block until every worker issues the matching
collective), returns the leader's last-position logits Tensor.
- WorkerPool::clear_kv_cache + unload_model follow the same pattern.
NcclState refactor (tp/nccl_state.rs):
- comm field becomes Option<Arc<Comm>> (was Option<Comm>) so callers
can share a clone with TpQwen3ForCausalLM::load.
- new `comm()` accessor + `SendComm` wrapper for spawn_blocking moves.
- single allow(clippy::arc_with_non_send_sync) at the canonical
construction site (Comm is !Send by type but the runtime invariant
is enforced by SendComm + the pool's Mutex).
Harness side (candle.rs):
- LoadedHandle enum (Single | Tp) replaces the bare Arc<LoadedModel>
in the harness's registry. list_models / unload_model /
inference_endpoint walk the enum uniformly.
- TpLoadedModel holds the pool + leader_model + tokenizer + devices.
- load_model dispatches on `spec.tensor_parallel > 1` to a new
cuda-gated load_tp path: resolve dense files via hf-hub, spawn the
pool, init_nccl, load_dense_shard.
- chat_completion branches on the handle variant. The TP path mirrors
run_inference: clear_kv_cache, prefill, sample, decode loop,
detokenize. Acquires the pool Mutex for the whole request.
- Streaming through TP is deferred to Stage 7c (returns Other(err)).
Script (script/validate-neuron.sh):
- 4th positional arg `tp_size` (default 1). When >1, switches to the
dense path (tp + GGUF is mutually exclusive — bails) and adds
`tensor_parallel` + `devices` to the load payload. NEURON_DEVICES
env overrides the default 0..N-1 device list.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -31,11 +31,44 @@ use tokio::sync::{Mutex, RwLock, mpsc};
|
||||
|
||||
/// In-process candle harness. Owns the loaded model registry.
|
||||
pub struct CandleHarness {
|
||||
models: Arc<RwLock<HashMap<String, Arc<LoadedModel>>>>,
|
||||
models: Arc<RwLock<HashMap<String, LoadedHandle>>>,
|
||||
hf_cache: Option<PathBuf>,
|
||||
bind_url: String,
|
||||
}
|
||||
|
||||
/// One entry in the harness's loaded-model registry. Single-GPU loads
|
||||
/// land in `Single`; loads with `tensor_parallel > 1` land in `Tp`.
|
||||
/// The two variants share the same `model_id` key in the map, so
|
||||
/// `list_models`, `unload_model`, and `inference_endpoint` can walk
|
||||
/// them uniformly without branching the storage layout.
|
||||
///
|
||||
/// `Clone` is cheap: both variants hold `Arc<_>` and cloning just bumps
|
||||
/// the refcount.
|
||||
#[derive(Clone)]
|
||||
pub enum LoadedHandle {
|
||||
Single(Arc<LoadedModel>),
|
||||
#[cfg(feature = "cuda")]
|
||||
Tp(Arc<TpLoadedModel>),
|
||||
}
|
||||
|
||||
impl LoadedHandle {
|
||||
pub fn model_id(&self) -> &str {
|
||||
match self {
|
||||
LoadedHandle::Single(m) => &m.model_id,
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => &m.model_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn devices(&self) -> Vec<u32> {
|
||||
match self {
|
||||
LoadedHandle::Single(m) => m.devices.clone(),
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => m.devices.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A loaded model with its tokenizer, device placement, and architecture-
|
||||
/// specific weights. The `arch` is `Arc<Mutex<>>` so the lock guard can be
|
||||
/// moved into `spawn_blocking` for synchronous candle forward passes.
|
||||
@@ -48,6 +81,25 @@ pub struct LoadedModel {
|
||||
pub devices: Vec<u32>,
|
||||
}
|
||||
|
||||
/// Tensor-parallel loaded model. Holds the leader's rank-0 shard
|
||||
/// (which the inference loop drives via spawn_blocking) and the
|
||||
/// `WorkerPool` (which drives every non-zero rank over the RPC
|
||||
/// channel). Both are behind tokio Mutexes so concurrent inference
|
||||
/// requests against the same model are serialised; concurrent loads
|
||||
/// for *different* models would each have their own pool.
|
||||
#[cfg(feature = "cuda")]
|
||||
pub struct TpLoadedModel {
|
||||
pub model_id: String,
|
||||
pub tokenizer: Tokenizer,
|
||||
pub devices: Vec<u32>,
|
||||
/// One end-to-end gate: the pool's RPC stream isn't safe to use
|
||||
/// concurrently and the leader shard's KV cache mutates with every
|
||||
/// step. The same Mutex covers both for the simplest correctness
|
||||
/// story.
|
||||
pub pool: tokio::sync::Mutex<super::tp::WorkerPool>,
|
||||
pub leader_model: Arc<tokio::sync::Mutex<super::tp::tp_qwen3::TpQwen3ForCausalLM>>,
|
||||
}
|
||||
|
||||
/// Architecture-specific weights.
|
||||
///
|
||||
/// - `Qwen3Quantized` — GGUF source, pre-quantized. Single-GPU only;
|
||||
@@ -357,11 +409,22 @@ impl CandleHarness {
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||
let loaded = {
|
||||
let handle = {
|
||||
let models = self.models.read().await;
|
||||
models.get(&request.model).cloned()
|
||||
};
|
||||
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
// The match is technically infallible without `cuda` (only Single
|
||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||
// under both feature flags.
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let loaded = match handle {
|
||||
LoadedHandle::Single(m) => m,
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(m) => {
|
||||
return self.chat_completion_tp(m, request).await;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
|
||||
@@ -451,11 +514,29 @@ impl CandleHarness {
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<mpsc::Receiver<ChatCompletionChunk>, InferenceError> {
|
||||
let loaded = {
|
||||
let handle = {
|
||||
let models = self.models.read().await;
|
||||
models.get(&request.model).cloned()
|
||||
};
|
||||
let loaded = loaded.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
let handle = handle.ok_or_else(|| InferenceError::ModelNotLoaded(request.model.clone()))?;
|
||||
// The match is technically infallible without `cuda` (only Single
|
||||
// exists), but the cfg-gated Tp arm makes this the right shape
|
||||
// under both feature flags.
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let loaded = match handle {
|
||||
LoadedHandle::Single(m) => m,
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(_) => {
|
||||
// Streaming through TP is Stage 7c work — the
|
||||
// non-streaming path drives the same forwards through
|
||||
// the pool but doesn't have to interleave SSE writes
|
||||
// with spawn_blocking forwards.
|
||||
return Err(InferenceError::Other(anyhow::anyhow!(
|
||||
"streaming chat completions through TP are not yet supported; \
|
||||
retry with stream=false"
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let encoding = loaded
|
||||
@@ -552,11 +633,11 @@ impl Harness for CandleHarness {
|
||||
let models = self.models.read().await;
|
||||
Ok(models
|
||||
.values()
|
||||
.map(|m| ModelInfo {
|
||||
id: m.model_id.clone(),
|
||||
.map(|h| ModelInfo {
|
||||
id: h.model_id().into(),
|
||||
harness: "candle".into(),
|
||||
status: "loaded".into(),
|
||||
devices: m.devices.clone(),
|
||||
devices: h.devices(),
|
||||
vram_used_mb: None,
|
||||
})
|
||||
.collect())
|
||||
@@ -574,19 +655,20 @@ impl Harness for CandleHarness {
|
||||
}
|
||||
}
|
||||
|
||||
// Stage 7a-i scaffolds tensor-parallel worker subprocesses but
|
||||
// does not yet route inference through them. Refuse TP loads
|
||||
// for now with a clear marker so the request surface is honest;
|
||||
// Stage 7b-iv replaces this bail with the TP dispatch.
|
||||
let tp_size = spec.tensor_parallel.unwrap_or(1);
|
||||
if tp_size > 1 {
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} requested for '{}': TP worker \
|
||||
lifecycle + NCCL handshake are in place (Stage 7a) but \
|
||||
TP-aware Qwen3 inference orchestration lands in Stage \
|
||||
7b-iv; single-GPU loads only for now",
|
||||
spec.model_id
|
||||
);
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
return self.load_tp(spec, tp_size).await;
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
{
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} requested for '{}': this neuron \
|
||||
binary was built without --features cuda; TP requires CUDA + NCCL",
|
||||
spec.model_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let devices = spec.devices.clone().unwrap_or_else(|| vec![0]);
|
||||
@@ -615,15 +697,52 @@ impl Harness for CandleHarness {
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
models.insert(spec.model_id.clone(), loaded);
|
||||
models.insert(spec.model_id.clone(), LoadedHandle::Single(loaded));
|
||||
tracing::info!(model = %spec.model_id, "model loaded");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unload_model(&self, model_id: &str) -> Result<()> {
|
||||
let mut models = self.models.write().await;
|
||||
if models.remove(model_id).is_none() {
|
||||
let removed = {
|
||||
let mut models = self.models.write().await;
|
||||
models.remove(model_id)
|
||||
};
|
||||
let Some(handle) = removed else {
|
||||
anyhow::bail!("model '{model_id}' not loaded");
|
||||
};
|
||||
// Single-GPU drops are immediate — the LoadedModel goes out of
|
||||
// scope with the Arc and candle frees VRAM. TP unloads also
|
||||
// need to tell every worker to drop its shard before the pool
|
||||
// itself is dropped (otherwise the workers keep their shards
|
||||
// around until Shutdown, which is wasteful and would surface
|
||||
// as VRAM not freed promptly).
|
||||
match handle {
|
||||
LoadedHandle::Single(_) => {}
|
||||
#[cfg(feature = "cuda")]
|
||||
LoadedHandle::Tp(tp) => {
|
||||
// Try to recover the inner TpLoadedModel so we can move
|
||||
// the pool and shut it down. If anyone else still holds
|
||||
// a clone of the Arc (shouldn't happen — the only owners
|
||||
// are the registry and any in-flight chat_completion),
|
||||
// bail with a clear marker rather than silently leaking.
|
||||
let tp = match Arc::try_unwrap(tp) {
|
||||
Ok(t) => t,
|
||||
Err(arc) => {
|
||||
// Reinsert so we don't leave the registry in an
|
||||
// inconsistent state.
|
||||
let mut models = self.models.write().await;
|
||||
models.insert(model_id.into(), LoadedHandle::Tp(arc));
|
||||
anyhow::bail!("cannot unload '{model_id}': inference still in flight");
|
||||
}
|
||||
};
|
||||
let mut pool = tp.pool.into_inner();
|
||||
if let Err(e) = pool.unload_model(model_id).await {
|
||||
tracing::warn!(model = %model_id, error = %e, "TP unload RPC failed");
|
||||
}
|
||||
if let Err(e) = pool.shutdown().await {
|
||||
tracing::warn!(model = %model_id, error = %e, "TP pool shutdown failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::info!(model = %model_id, "model unloaded");
|
||||
Ok(())
|
||||
@@ -635,6 +754,215 @@ impl Harness for CandleHarness {
|
||||
}
|
||||
}
|
||||
|
||||
impl CandleHarness {
|
||||
/// Tensor-parallel load. Resolves dense safetensors via hf-hub the
|
||||
/// same way the single-GPU dense path does, spins up a TP worker
|
||||
/// pool sized to `tp_size`, runs the NCCL handshake, then has
|
||||
/// every rank load its shard of the model.
|
||||
///
|
||||
/// `spec.devices` carries the per-rank CUDA device indices (one
|
||||
/// entry per rank, in rank order); defaults to `0..tp_size`.
|
||||
#[cfg(feature = "cuda")]
|
||||
async fn load_tp(&self, spec: &ModelSpec, tp_size: u32) -> Result<()> {
|
||||
use std::sync::Arc as StdArc;
|
||||
use tokio::sync::Mutex as TMutex;
|
||||
|
||||
// Default per-rank device assignment: 0, 1, ..., tp_size - 1.
|
||||
let devices = spec
|
||||
.devices
|
||||
.clone()
|
||||
.unwrap_or_else(|| (0..tp_size).collect());
|
||||
if devices.len() as u32 != tp_size {
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} requires {tp_size} entries in devices, got {}",
|
||||
devices.len()
|
||||
);
|
||||
}
|
||||
if spec.quant.is_some() {
|
||||
anyhow::bail!(
|
||||
"tensor_parallel={tp_size} with quant={:?}: GGUF quantized models \
|
||||
are not supported in the TP path; use a dense safetensors source",
|
||||
spec.quant
|
||||
);
|
||||
}
|
||||
|
||||
// 1. Resolve config + tokenizer + safetensors via hf-hub.
|
||||
let (config_path, tokenizer_path, safetensors_paths) =
|
||||
self.resolve_dense_files(spec).await?;
|
||||
let config_json = std::fs::read_to_string(&config_path).context("read config.json")?;
|
||||
|
||||
// 2. Spawn the worker pool. Rank 0 stays in-process; ranks
|
||||
// 1..tp_size are subprocesses, one per device after the
|
||||
// leader's own.
|
||||
let exe = std::env::current_exe().context("resolve current_exe for worker spawn")?;
|
||||
let mut pool = super::tp::WorkerPool::spawn(&exe, tp_size, &devices).await?;
|
||||
|
||||
// 3. NCCL handshake across all ranks.
|
||||
let leader_device_idx = devices[0];
|
||||
pool.init_nccl(leader_device_idx).await?;
|
||||
|
||||
// 4. Pick the leader's candle Device (same index as init_nccl).
|
||||
let leader_device = candle_core::Device::new_cuda(leader_device_idx as usize)
|
||||
.context("Device::new_cuda for TP leader")?;
|
||||
|
||||
// 5. Load this rank's shard on every rank.
|
||||
let leader_model = pool
|
||||
.load_dense_shard(
|
||||
&spec.model_id,
|
||||
&config_json,
|
||||
&safetensors_paths,
|
||||
&leader_device,
|
||||
candle_core::DType::BF16,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// 6. Tokenizer (same as single-GPU path).
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_path)
|
||||
.map_err(|e| anyhow::anyhow!("load tokenizer: {e}"))?;
|
||||
|
||||
let tp_loaded = StdArc::new(TpLoadedModel {
|
||||
model_id: spec.model_id.clone(),
|
||||
tokenizer,
|
||||
devices: devices.clone(),
|
||||
pool: TMutex::new(pool),
|
||||
leader_model,
|
||||
});
|
||||
|
||||
let mut models = self.models.write().await;
|
||||
models.insert(spec.model_id.clone(), LoadedHandle::Tp(tp_loaded));
|
||||
tracing::info!(
|
||||
model = %spec.model_id,
|
||||
tp_size,
|
||||
?devices,
|
||||
"TP model loaded"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Non-streaming chat completion against a TP model. Pattern mirrors
|
||||
/// the single-GPU `run_inference`: tokenize, prefill, sample, decode
|
||||
/// loop, detokenize. Each forward step fans out to every rank via
|
||||
/// the WorkerPool and uses the leader's last-position logits to
|
||||
/// sample.
|
||||
#[cfg(feature = "cuda")]
|
||||
async fn chat_completion_tp(
|
||||
&self,
|
||||
tp: Arc<TpLoadedModel>,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, InferenceError> {
|
||||
let prompt = format_qwen3_prompt(&request.messages);
|
||||
let encoding = tp
|
||||
.tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("tokenize: {e}")))?;
|
||||
let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
|
||||
let prompt_len = prompt_tokens.len();
|
||||
|
||||
let temperature = request.temperature.unwrap_or(0.7);
|
||||
let top_p = request.top_p;
|
||||
let max_new = request.max_tokens.unwrap_or(512) as usize;
|
||||
let seed = unix_subsec_nanos();
|
||||
|
||||
let eos_id = tp
|
||||
.tokenizer
|
||||
.token_to_id("<|im_end|>")
|
||||
.or_else(|| tp.tokenizer.token_to_id("<|endoftext|>"));
|
||||
|
||||
let model_id = request.model.clone();
|
||||
|
||||
// Acquire the pool lock for the duration of the request. The
|
||||
// leader_model's own Mutex is acquired step-by-step inside
|
||||
// pool.generate_step (so spawn_blocking can grab it without
|
||||
// holding the pool lock across the blocking_lock call).
|
||||
let mut pool = tp.pool.lock().await;
|
||||
let leader_arc = tp.leader_model.clone();
|
||||
|
||||
// Reset every rank's KV cache so this request doesn't attend
|
||||
// over the previous request's tokens.
|
||||
pool.clear_kv_cache(&model_id, leader_arc.clone())
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
|
||||
let mut logits_processor = {
|
||||
let sampling = if temperature <= 0.0 {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match top_p {
|
||||
Some(p) => Sampling::TopP { p, temperature },
|
||||
None => Sampling::All { temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(seed, sampling)
|
||||
};
|
||||
|
||||
let mut generated: Vec<u32> = Vec::new();
|
||||
let mut finish_reason = "length".to_string();
|
||||
|
||||
// Prefill: every rank embeds the whole prompt, offset = 0.
|
||||
let logits = pool
|
||||
.generate_step(&model_id, leader_arc.clone(), prompt_tokens.clone(), 0)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
let mut next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
|
||||
.map_err(InferenceError::Other)?;
|
||||
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = "stop".into();
|
||||
} else {
|
||||
generated.push(next_token);
|
||||
for index in 0..max_new.saturating_sub(1) {
|
||||
let logits = pool
|
||||
.generate_step(
|
||||
&model_id,
|
||||
leader_arc.clone(),
|
||||
vec![next_token],
|
||||
prompt_len + index,
|
||||
)
|
||||
.await
|
||||
.map_err(InferenceError::Other)?;
|
||||
next_token = sample_with_penalty(&logits, &generated, &mut logits_processor)
|
||||
.map_err(InferenceError::Other)?;
|
||||
if Some(next_token) == eos_id {
|
||||
finish_reason = "stop".into();
|
||||
break;
|
||||
}
|
||||
generated.push(next_token);
|
||||
}
|
||||
}
|
||||
drop(pool);
|
||||
|
||||
let completion_text = tp
|
||||
.tokenizer
|
||||
.decode(&generated, true)
|
||||
.map_err(|e| InferenceError::Other(anyhow::anyhow!("detokenize: {e}")))?;
|
||||
|
||||
let usage = Usage {
|
||||
prompt_tokens: prompt_len as u64,
|
||||
completion_tokens: generated.len() as u64,
|
||||
total_tokens: (prompt_len + generated.len()) as u64,
|
||||
};
|
||||
|
||||
Ok(ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{:x}", unix_subsec_nanos()),
|
||||
object: "chat.completion".into(),
|
||||
created: unix_now_secs(),
|
||||
model: model_id,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
message: ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: MessageContent::Text(completion_text),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
},
|
||||
finish_reason: Some(finish_reason),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
}],
|
||||
usage: Some(usage),
|
||||
extra: serde_json::Value::Object(Default::default()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors returned by `CandleHarness::chat_completion`. The
|
||||
/// `ModelNotLoaded` variant lets the HTTP handler map cleanly to 404
|
||||
/// without string-matching on anyhow messages.
|
||||
|
||||
Reference in New Issue
Block a user