diff --git a/src/agent.rs b/src/agent.rs index bcb8988..baf0514 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -189,7 +189,8 @@ pub async fn run(cli: &Cli) -> Result<()> { // Load DSL schema for the system prompt let schema = include_str!("dsl-schema.json"); - let system = prompts::system_prompt(schema); + let system = prompts::system_prompt(schema, claude.family()); + info!("model family: {}", claude.family().name()); // Agent state let mut history: Vec = Vec::new(); @@ -267,10 +268,11 @@ pub async fn run(cli: &Cli) -> Result<()> { let strategy = match claude::extract_json(&response_text) { Ok(s) => s, Err(e) => { - warn!("failed to extract strategy JSON: {e}"); + warn!("failed to extract strategy JSON: {e:#}"); warn!( - "raw response: {}", - &response_text[..response_text.len().min(500)] + "raw response ({} chars): {}", + response_text.len(), + &response_text[..response_text.len().min(800)] ); consecutive_failures += 1; if consecutive_failures >= 3 { diff --git a/src/claude.rs b/src/claude.rs index 37381cc..8b96a77 100644 --- a/src/claude.rs +++ b/src/claude.rs @@ -3,11 +3,14 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; +use crate::config::ModelFamily; + pub struct ClaudeClient { client: Client, api_key: String, api_url: String, model: String, + family: ModelFamily, } #[derive(Serialize)] @@ -43,8 +46,11 @@ pub struct Usage { impl ClaudeClient { pub fn new(api_key: &str, api_url: &str, model: &str) -> Self { + let family = ModelFamily::detect(model); + // R1 thinking can take several minutes; use a generous timeout (300 s for thinking families, 120 s otherwise). 
+ let timeout_secs = if family.has_thinking() { 300 } else { 120 }; let client = Client::builder() - .timeout(std::time::Duration::from_secs(120)) + .timeout(std::time::Duration::from_secs(timeout_secs)) .build() .expect("build http client"); Self { @@ -52,9 +58,14 @@ impl ClaudeClient { api_key: api_key.to_string(), api_url: api_url.to_string(), model: model.to_string(), + family, } } + pub fn family(&self) -> &ModelFamily { + &self.family + } + /// Send a conversation to Claude and get the text response. pub async fn chat( &self, @@ -63,7 +74,7 @@ impl ClaudeClient { ) -> Result<(String, Option)> { let body = MessagesRequest { model: self.model.clone(), - max_tokens: 8192, + max_tokens: self.family.max_output_tokens(), system: system.to_string(), messages: messages.to_vec(), }; diff --git a/src/config.rs b/src/config.rs index d830b40..a9ad403 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,6 +2,50 @@ use std::path::PathBuf; use clap::Parser; +/// Model family — controls token budgets and prompt style. +#[derive(Debug, Clone, PartialEq)] +pub enum ModelFamily { + /// DeepSeek-R1 and its distillations: emit `<think>` blocks that count + /// against the output-token budget, so we need a much larger max_tokens. + DeepSeekR1, + /// General instruction-following models (Qwen, Llama, Mistral, …). + Generic, +} + +impl ModelFamily { + /// Detect family from a model name string (case-insensitive): names containing `deepseek-r1`, `r1-distill`, or `r1_distill` map to `DeepSeekR1`, all others to `Generic`. + pub fn detect(model: &str) -> Self { + let m = model.to_ascii_lowercase(); + if m.contains("deepseek-r1") || m.contains("r1-distill") || m.contains("r1_distill") { + Self::DeepSeekR1 + } else { + Self::Generic + } + } + + /// Display name for logging. + pub fn name(&self) -> &'static str { + match self { + Self::DeepSeekR1 => "DeepSeek-R1", + Self::Generic => "Generic", + } + } + + /// Maximum output tokens to request. R1 thinking blocks can be thousands + /// of tokens; reserve enough headroom for the JSON after thinking. 
+ pub fn max_output_tokens(&self) -> u32 { + match self { + Self::DeepSeekR1 => 32768, + Self::Generic => 8192, + } + } + + /// Whether this model family emits chain-of-thought before its response. + pub fn has_thinking(&self) -> bool { + matches!(self, Self::DeepSeekR1) + } +} + /// Autonomous strategy search agent for the swym backtesting platform. /// /// Runs a loop: ask Claude to generate/refine strategies → submit backtests to swym → diff --git a/src/prompts.rs b/src/prompts.rs index 7826b9c..ce26576 100644 --- a/src/prompts.rs +++ b/src/prompts.rs @@ -1,9 +1,28 @@ -/// System prompt for the strategy-generation Claude instance. +use crate::config::ModelFamily; + +/// System prompt for the strategy-generation model. /// -/// This is the most important part of the agent — it defines how Claude -/// thinks about strategy design, what it knows about the DSL, and how -/// it should interpret backtest results. -pub fn system_prompt(dsl_schema: &str) -> String { +/// Accepts a `ModelFamily` so each family can receive tailored guidance +/// while sharing the common DSL schema and strategy evaluation rules. NOTE(review): the DeepSeekR1 text omits the rule_based type and usdc quote-asset requirements that the Generic text states — confirm this is intentional. +pub fn system_prompt(dsl_schema: &str, family: &ModelFamily) -> String { + let output_instructions = match family { + ModelFamily::DeepSeekR1 => { + "## Output format\n\n\ + Think through your strategy design carefully before committing to it. \ + After your thinking, output ONLY a bare JSON object — no markdown fences, \ + no commentary, no explanation. Start with `{` and end with `}`. \ + Your thinking will be stripped automatically; only the JSON is used." + } + ModelFamily::Generic => { + "## How to respond\n\n\ + You must respond with ONLY a valid JSON object — the strategy config.\n\ + No prose, no markdown explanation, no commentary.\n\ + Just the raw JSON starting with { and ending with }.\n\n\ + The JSON must be a valid strategy with \"type\": \"rule_based\".\n\ + Use \"usdc\" (not \"usdt\") as the quote asset for balance expressions." 
+ } + }; + format!( r##"You are a quantitative trading strategy researcher. Your task is to design, evaluate, and iteratively refine trading strategies expressed in the swym JSON DSL. @@ -88,14 +107,7 @@ Every strategy MUST have: - A time-based exit: use bars_since_entry to avoid holding losers indefinitely - Reasonable position sizing: prefer ATR-based or percent-of-balance over fixed quantity -## How to respond - -You must respond with ONLY a valid JSON object — the strategy config. -No prose, no markdown explanation, no commentary. -Just the raw JSON starting with {{ and ending with }}. - -The JSON must be a valid strategy with "type": "rule_based". -Use "usdc" (not "usdt") as the quote asset for balance expressions. +{output_instructions} ## Interpreting backtest results