Add cross-run learning via run ledger and compare endpoint
Persist strategy + run_id to results/run_ledger.jsonl after each backtest. On startup, load the ledger, fetch metrics via the new compare endpoint (batched in groups of 50), group by strategy, rank by avg Sharpe, and inject a summary of the top 5 and worst 3 prior strategies into the iteration-1 prompt. Also consumes the enriched result_summary fields from swym patch e47c18: sortino_ratio, calmar_ratio, max_drawdown, pnl_return, avg_win, avg_loss, max_win, max_loss, avg_hold_duration_secs. Sortino and max_drawdown are appended to summary_line() when present. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
196
src/agent.rs
196
src/agent.rs
@@ -1,14 +1,26 @@
|
||||
use std::io::Write as IoWrite;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::claude::{self, ClaudeClient, Message};
|
||||
use crate::config::{Cli, Instrument};
|
||||
use crate::prompts;
|
||||
use crate::swym::{BacktestResult, SwymClient};
|
||||
use crate::swym::{BacktestResult, RunMetricsSummary, SwymClient};
|
||||
|
||||
/// Persistent record of a single completed backtest, written to the run ledger.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LedgerEntry {
|
||||
run_id: Uuid,
|
||||
instrument: String,
|
||||
candle_interval: String,
|
||||
strategy: Value,
|
||||
}
|
||||
|
||||
/// A single iteration's record: strategy + results across instruments.
|
||||
#[derive(Debug)]
|
||||
@@ -193,6 +205,9 @@ pub async fn run(cli: &Cli) -> Result<()> {
|
||||
let system = prompts::system_prompt(schema, claude.family());
|
||||
info!("model family: {}", claude.family().name());
|
||||
|
||||
// Load prior runs from ledger and build cross-run context for iteration 1
|
||||
let prior_summary = load_prior_summary(&cli.output_dir, &swym).await;
|
||||
|
||||
// Agent state
|
||||
let mut history: Vec<IterationRecord> = Vec::new();
|
||||
let mut conversation: Vec<Message> = Vec::new();
|
||||
@@ -206,7 +221,7 @@ pub async fn run(cli: &Cli) -> Result<()> {
|
||||
|
||||
// Build the user prompt
|
||||
let user_msg = if iteration == 1 {
|
||||
prompts::initial_prompt(&instrument_names, &available_intervals)
|
||||
prompts::initial_prompt(&instrument_names, &available_intervals, prior_summary.as_deref())
|
||||
} else {
|
||||
let results_text = history
|
||||
.iter()
|
||||
@@ -397,12 +412,13 @@ pub async fn run(cli: &Cli) -> Result<()> {
|
||||
info!(" condition audit: {}", serde_json::to_string_pretty(audit).unwrap_or_default());
|
||||
}
|
||||
}
|
||||
append_ledger_entry(&cli.output_dir, &result, &strategy);
|
||||
results.push(result);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(" backtest failed for {}: {e:#}", inst.symbol);
|
||||
results.push(BacktestResult {
|
||||
run_id: uuid::Uuid::nil(),
|
||||
run_id: Uuid::nil(),
|
||||
instrument: inst.symbol.clone(),
|
||||
status: "failed".to_string(),
|
||||
total_positions: None,
|
||||
@@ -413,6 +429,15 @@ pub async fn run(cli: &Cli) -> Result<()> {
|
||||
total_pnl: None,
|
||||
net_pnl: None,
|
||||
sharpe_ratio: None,
|
||||
sortino_ratio: None,
|
||||
calmar_ratio: None,
|
||||
max_drawdown: None,
|
||||
pnl_return: None,
|
||||
avg_win: None,
|
||||
avg_loss: None,
|
||||
max_win: None,
|
||||
max_loss: None,
|
||||
avg_hold_duration_secs: None,
|
||||
total_fees: None,
|
||||
avg_bars_in_trade: None,
|
||||
error_message: Some(e.to_string()),
|
||||
@@ -573,6 +598,171 @@ async fn run_single_backtest(
|
||||
Ok(BacktestResult::from_response(&final_resp, &inst.symbol))
|
||||
}
|
||||
|
||||
/// Append a ledger entry for a completed backtest so future runs can learn from it.
|
||||
fn append_ledger_entry(output_dir: &Path, result: &BacktestResult, strategy: &Value) {
|
||||
// Skip nil run_ids (error placeholders)
|
||||
if result.run_id == Uuid::nil() {
|
||||
return;
|
||||
}
|
||||
let entry = LedgerEntry {
|
||||
run_id: result.run_id,
|
||||
instrument: result.instrument.clone(),
|
||||
candle_interval: strategy["candle_interval"]
|
||||
.as_str()
|
||||
.unwrap_or("?")
|
||||
.to_string(),
|
||||
strategy: strategy.clone(),
|
||||
};
|
||||
let line = match serde_json::to_string(&entry) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!("could not serialize ledger entry: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let path = output_dir.join("run_ledger.jsonl");
|
||||
if let Err(e) = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.create(true)
|
||||
.open(&path)
|
||||
.and_then(|mut f| writeln!(f, "{}", line))
|
||||
{
|
||||
warn!("could not write ledger entry: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the run ledger, fetch metrics via the compare endpoint, and return a compact
|
||||
/// prior-results summary string for the initial prompt. Returns `None` if the ledger
|
||||
/// is absent, empty, or the compare call fails.
|
||||
async fn load_prior_summary(output_dir: &Path, swym: &SwymClient) -> Option<String> {
|
||||
let path = output_dir.join("run_ledger.jsonl");
|
||||
let contents = std::fs::read_to_string(&path).ok()?;
|
||||
|
||||
// Parse all ledger entries
|
||||
let entries: Vec<LedgerEntry> = contents
|
||||
.lines()
|
||||
.filter(|l| !l.trim().is_empty())
|
||||
.filter_map(|l| serde_json::from_str(l).ok())
|
||||
.collect();
|
||||
if entries.is_empty() {
|
||||
return None;
|
||||
}
|
||||
info!("loaded {} ledger entries from previous runs", entries.len());
|
||||
|
||||
// Fetch metrics for all run_ids
|
||||
let run_ids: Vec<Uuid> = entries.iter().map(|e| e.run_id).collect();
|
||||
let metrics = match swym.compare_runs(&run_ids).await {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
warn!("could not fetch prior run metrics: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Build a map from run_id → metrics
|
||||
let metrics_map: std::collections::HashMap<Uuid, &RunMetricsSummary> =
|
||||
metrics.iter().map(|m| (m.id, m)).collect();
|
||||
|
||||
// Group entries by strategy (use candle_interval + rules fingerprint)
|
||||
// We use the full strategy JSON as the grouping key.
|
||||
let mut strategy_groups: std::collections::HashMap<String, Vec<(&LedgerEntry, Option<&RunMetricsSummary>)>> =
|
||||
std::collections::HashMap::new();
|
||||
for entry in &entries {
|
||||
let key = serde_json::to_string(&entry.strategy).unwrap_or_default();
|
||||
let m = metrics_map.get(&entry.run_id).copied();
|
||||
strategy_groups.entry(key).or_default().push((entry, m));
|
||||
}
|
||||
|
||||
// Compute avg sharpe per strategy group
|
||||
let mut strategies: Vec<(f64, &Value, Vec<(&LedgerEntry, Option<&RunMetricsSummary>)>)> = strategy_groups
|
||||
.into_values()
|
||||
.map(|group| {
|
||||
let sharpes: Vec<f64> = group
|
||||
.iter()
|
||||
.filter_map(|(_, m)| m.and_then(|m| m.sharpe_ratio))
|
||||
.collect();
|
||||
let avg_sharpe = if sharpes.is_empty() {
|
||||
f64::NEG_INFINITY
|
||||
} else {
|
||||
sharpes.iter().sum::<f64>() / sharpes.len() as f64
|
||||
};
|
||||
let strategy = &group[0].0.strategy;
|
||||
(avg_sharpe, strategy, group)
|
||||
})
|
||||
.collect();
|
||||
strategies.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let total_strategies = strategies.len();
|
||||
let total_backtests = entries.len();
|
||||
|
||||
// Build summary text — top 5 + bottom 3 (if distinct), capped at ~2000 chars
|
||||
let mut lines = vec![format!(
|
||||
"## Learnings from {} prior backtests across {} strategies\n",
|
||||
total_backtests, total_strategies
|
||||
)];
|
||||
lines.push("### Best strategies (ranked by avg Sharpe):".to_string());
|
||||
|
||||
let show_top = strategies.len().min(5);
|
||||
for (avg_sharpe, strategy, group) in strategies.iter().take(show_top) {
|
||||
let interval = strategy["candle_interval"].as_str().unwrap_or("?");
|
||||
let rule_count = strategy["rules"].as_array().map(|r| r.len()).unwrap_or(0);
|
||||
// Collect per-instrument metrics
|
||||
let inst_lines: Vec<String> = group
|
||||
.iter()
|
||||
.filter_map(|(entry, m)| {
|
||||
let m = (*m)?;
|
||||
Some(format!(
|
||||
" {}: trades={} sharpe={:.3} net_pnl={:.2}{}",
|
||||
entry.instrument,
|
||||
m.total_positions.unwrap_or(0),
|
||||
m.sharpe_ratio.unwrap_or(0.0),
|
||||
m.net_pnl.unwrap_or(0.0),
|
||||
m.max_drawdown.map(|d| format!(" max_dd={:.1}%", d * 100.0)).unwrap_or_default(),
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Pull the first rule comment as a strategy description
|
||||
let description = strategy["rules"][0]["comment"]
|
||||
.as_str()
|
||||
.unwrap_or("(no description)");
|
||||
lines.push(format!(
|
||||
"\n [{interval}, {rule_count} rules, avg_sharpe={avg_sharpe:.3}] {description}"
|
||||
));
|
||||
lines.extend(inst_lines);
|
||||
// Include full JSON only for the top 2
|
||||
let rank = strategies.iter().position(|(_, s, _)| std::ptr::eq(*s, *strategy)).unwrap_or(99);
|
||||
if rank < 2 {
|
||||
lines.push(format!(
|
||||
" strategy JSON: {}",
|
||||
serde_json::to_string(strategy).unwrap_or_default()
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Worst 3 (if we have more than 5)
|
||||
if strategies.len() > 5 {
|
||||
lines.push("\n### Worst strategies (avoid repeating these):".to_string());
|
||||
let worst_start = strategies.len().saturating_sub(3);
|
||||
for (avg_sharpe, strategy, _) in strategies.iter().skip(worst_start) {
|
||||
let interval = strategy["candle_interval"].as_str().unwrap_or("?");
|
||||
let description = strategy["rules"][0]["comment"].as_str().unwrap_or("(no description)");
|
||||
lines.push(format!(" [{interval}, avg_sharpe={avg_sharpe:.3}] {description}"));
|
||||
}
|
||||
}
|
||||
|
||||
lines.push(format!(
|
||||
"\nUse these results to avoid repeating failed approaches and build on what worked.\n"
|
||||
));
|
||||
|
||||
let summary = lines.join("\n");
|
||||
// Truncate to ~6000 chars to stay within prompt budget
|
||||
if summary.len() > 6000 {
|
||||
Some(format!("{}…\n[truncated — {} total strategies]\n", &summary[..5900], total_strategies))
|
||||
} else {
|
||||
Some(summary)
|
||||
}
|
||||
}
|
||||
|
||||
fn save_validated_strategy(
|
||||
output_dir: &Path,
|
||||
iteration: u32,
|
||||
|
||||
@@ -493,9 +493,14 @@ CRITICAL: `apply_func` uses `"input"`, not `"expr"`. Writing `"expr":` will be r
|
||||
}
|
||||
|
||||
/// Build the user message for the first iteration (no prior results).
|
||||
pub fn initial_prompt(instruments: &[String], candle_intervals: &[String]) -> String {
|
||||
/// `prior_summary` contains a formatted summary of results from previous runs, if any.
|
||||
pub fn initial_prompt(instruments: &[String], candle_intervals: &[String], prior_summary: Option<&str>) -> String {
|
||||
let prior_section = match prior_summary {
|
||||
Some(s) => format!("{s}\n\n"),
|
||||
None => String::new(),
|
||||
};
|
||||
format!(
|
||||
r#"Design a trading strategy for crypto spot markets.
|
||||
r#"{prior_section}Design a trading strategy for crypto spot markets.
|
||||
|
||||
Available instruments: {}
|
||||
Available candle intervals: {}
|
||||
|
||||
92
src/swym.rs
92
src/swym.rs
@@ -49,6 +49,37 @@ pub struct CandleCoverage {
|
||||
pub coverage_pct: Option<f64>,
|
||||
}
|
||||
|
||||
/// Response from `GET /api/v1/paper-runs/compare?ids=...`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RunMetricsSummary {
|
||||
pub id: Uuid,
|
||||
pub status: String,
|
||||
pub candle_interval: Option<String>,
|
||||
pub total_positions: Option<u32>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub win_rate: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub profit_factor: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub net_pnl: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub sharpe_ratio: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub sortino_ratio: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub calmar_ratio: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub max_drawdown: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub pnl_return: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub avg_win: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub avg_loss: Option<f64>,
|
||||
#[serde(default, deserialize_with = "deserialize_opt_number")]
|
||||
pub avg_hold_duration_secs: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BacktestResult {
|
||||
pub run_id: Uuid,
|
||||
@@ -62,6 +93,15 @@ pub struct BacktestResult {
|
||||
pub total_pnl: Option<f64>,
|
||||
pub net_pnl: Option<f64>,
|
||||
pub sharpe_ratio: Option<f64>,
|
||||
pub sortino_ratio: Option<f64>,
|
||||
pub calmar_ratio: Option<f64>,
|
||||
pub max_drawdown: Option<f64>,
|
||||
pub pnl_return: Option<f64>,
|
||||
pub avg_win: Option<f64>,
|
||||
pub avg_loss: Option<f64>,
|
||||
pub max_win: Option<f64>,
|
||||
pub max_loss: Option<f64>,
|
||||
pub avg_hold_duration_secs: Option<f64>,
|
||||
pub total_fees: Option<f64>,
|
||||
pub avg_bars_in_trade: Option<f64>,
|
||||
pub error_message: Option<String>,
|
||||
@@ -89,6 +129,15 @@ impl BacktestResult {
|
||||
let net_pnl = summary.and_then(|s| parse_number(&s["net_pnl"]));
|
||||
let total_pnl = summary.and_then(|s| parse_number(&s["total_pnl"]));
|
||||
let sharpe_ratio = summary.and_then(|s| parse_number(&s["sharpe_ratio"]));
|
||||
let sortino_ratio = summary.and_then(|s| parse_number(&s["sortino_ratio"]));
|
||||
let calmar_ratio = summary.and_then(|s| parse_number(&s["calmar_ratio"]));
|
||||
let max_drawdown = summary.and_then(|s| parse_number(&s["max_drawdown"]));
|
||||
let pnl_return = summary.and_then(|s| parse_number(&s["pnl_return"]));
|
||||
let avg_win = summary.and_then(|s| parse_number(&s["avg_win"]));
|
||||
let avg_loss = summary.and_then(|s| parse_number(&s["avg_loss"]));
|
||||
let max_win = summary.and_then(|s| parse_number(&s["max_win"]));
|
||||
let max_loss = summary.and_then(|s| parse_number(&s["max_loss"]));
|
||||
let avg_hold_duration_secs = summary.and_then(|s| parse_number(&s["avg_hold_duration_secs"]));
|
||||
let total_fees = summary.and_then(|s| parse_number(&s["total_fees"]));
|
||||
|
||||
Self {
|
||||
@@ -103,6 +152,15 @@ impl BacktestResult {
|
||||
total_pnl,
|
||||
net_pnl,
|
||||
sharpe_ratio,
|
||||
sortino_ratio,
|
||||
calmar_ratio,
|
||||
max_drawdown,
|
||||
pnl_return,
|
||||
avg_win,
|
||||
avg_loss,
|
||||
max_win,
|
||||
max_loss,
|
||||
avg_hold_duration_secs,
|
||||
total_fees,
|
||||
avg_bars_in_trade: None,
|
||||
error_message: resp.error_message.clone(),
|
||||
@@ -128,6 +186,12 @@ impl BacktestResult {
|
||||
self.net_pnl.unwrap_or(0.0),
|
||||
self.sharpe_ratio.unwrap_or(0.0),
|
||||
);
|
||||
if let Some(sortino) = self.sortino_ratio {
|
||||
s.push_str(&format!(" sortino={:.2}", sortino));
|
||||
}
|
||||
if let Some(dd) = self.max_drawdown {
|
||||
s.push_str(&format!(" max_dd={:.1}%", dd * 100.0));
|
||||
}
|
||||
if self.total_positions.unwrap_or(0) == 0 {
|
||||
if let Some(audit) = &self.condition_audit_summary {
|
||||
let audit_str = format_audit_summary(audit);
|
||||
@@ -160,6 +224,15 @@ fn parse_number(v: &Value) -> Option<f64> {
|
||||
if f.abs() > 1e20 { None } else { Some(f) }
|
||||
}
|
||||
|
||||
/// Serde deserializer for `Option<f64>` that accepts both JSON numbers and decimal strings.
|
||||
fn deserialize_opt_number<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let v = Value::deserialize(deserializer)?;
|
||||
Ok(parse_number(&v))
|
||||
}
|
||||
|
||||
/// Render a condition_audit_summary Value into a compact one-line string.
|
||||
///
|
||||
/// Handles the primary shape from the swym API:
|
||||
@@ -386,6 +459,25 @@ impl SwymClient {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch metrics for multiple completed runs via the compare endpoint.
|
||||
/// Batches requests in groups of 50 (API maximum).
|
||||
pub async fn compare_runs(&self, run_ids: &[Uuid]) -> Result<Vec<RunMetricsSummary>> {
|
||||
let mut results = Vec::new();
|
||||
for chunk in run_ids.chunks(50) {
|
||||
let ids = chunk.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(",");
|
||||
let url = format!("{}/paper-runs/compare?ids={}", self.base_url, ids);
|
||||
let resp = self.client.get(&url).send().await.context("compare runs request")?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("compare runs {status}: {body}");
|
||||
}
|
||||
let mut batch: Vec<RunMetricsSummary> = resp.json().await.context("parse compare response")?;
|
||||
results.append(&mut batch);
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Fetch condition audit summary for a completed run.
|
||||
pub async fn condition_audit(&self, run_id: Uuid) -> Result<Value> {
|
||||
let url = format!("{}/paper-runs/{}/condition-audit", self.base_url, run_id);
|
||||
|
||||
Reference in New Issue
Block a user