trtr

Trading simulator and technical-analysis gym
git clone https://git.ea.contact/trtr
Log | Files | Refs | README

handlers.rs (7019B)


      1 use std::sync::Arc;
      2 
      3 use axum::{
      4     Json,
      5     extract::State,
      6     http::StatusCode,
      7     response::{Html, IntoResponse},
      8 };
      9 use dashmap::DashMap;
     10 use rand::{Rng, rng};
     11 use serde::{Deserialize, Serialize};
     12 use uuid::Uuid;
     13 
     14 use crate::{data::AppData, session::{self, ActivePrediction, Direction}};
     15 
     16 #[derive(Clone)]
     17 pub struct AppState {
     18     pub data: Arc<AppData>,
     19     pub sessions: Arc<DashMap<String, session::SessionState>>,
     20 }
     21 
     22 // ── helpers ───────────────────────────────────────────────────────────────────
     23 
     24 fn is_leap(y: i32) -> bool {
     25     (y % 4 == 0 && y % 100 != 0) || y % 400 == 0
     26 }
     27 
     28 fn month_days(y: i32, m: u32) -> i64 {
     29     match m {
     30         1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
     31         4 | 6 | 9 | 11 => 30,
     32         2 => if is_leap(y) { 29 } else { 28 },
     33         _ => 30,
     34     }
     35 }
     36 
     37 fn period_start_ts(year: i32, month: u32) -> i64 {
     38     let mut days: i64 = 0;
     39     for y in 1970..year {
     40         days += if is_leap(y) { 366 } else { 365 };
     41     }
     42     for m in 1..month {
     43         days += month_days(year, m);
     44     }
     45     days * 86400
     46 }
     47 
     48 fn candle_range(candles: &[crate::data::Candle], year: i32, month: Option<u32>) -> (usize, usize) {
     49     let lo_ts = period_start_ts(year, month.unwrap_or(1));
     50     let hi_ts = match month {
     51         Some(12) | None => period_start_ts(year + 1, 1),
     52         Some(m) => period_start_ts(year, m + 1),
     53     };
     54     let lo = candles.partition_point(|c| c.ts < lo_ts);
     55     let hi = candles.partition_point(|c| c.ts < hi_ts);
     56     (lo, hi)
     57 }
     58 
     59 // ── /api/session/new ──────────────────────────────────────────────────────────
     60 
     61 #[derive(Deserialize, Default)]
     62 pub struct NewSessionRequest {
     63     pub year: Option<i32>,
     64     pub month: Option<u32>,
     65 }
     66 
     67 #[derive(Serialize)]
     68 pub struct NewSessionResponse {
     69     session_id: String,
     70     candles: Vec<crate::data::Candle>,
     71 }
     72 
     73 pub async fn new_session(
     74     State(state): State<AppState>,
     75     Json(req): Json<NewSessionRequest>,
     76 ) -> impl IntoResponse {
     77     let total = state.data.candles.len();
     78     let mut rng = rng();
     79 
     80     let start = match req.year {
     81         None => session::random_start_index(total, &mut rng),
     82         Some(year) => {
     83             let (lo, hi) = candle_range(&state.data.candles, year, req.month);
     84             if hi.saturating_sub(lo) < 100 {
     85                 return (
     86                     StatusCode::BAD_REQUEST,
     87                     Json(serde_json::json!({"error": "not_enough_data"})),
     88                 ).into_response();
     89             }
     90             let max = hi.saturating_sub(100).min(total.saturating_sub(5100)).max(lo);
     91             rng.random_range(lo..=max)
     92         }
     93     };
     94 
     95     let s = session::SessionState::new(start);
     96     let candles = state.data.candles[start..=s.current_index].to_vec();
     97     let id = Uuid::new_v4().to_string();
     98     state.sessions.insert(id.clone(), s);
     99 
    100     Json(NewSessionResponse { session_id: id, candles }).into_response()
    101 }
    102 
    103 // ── /api/next ─────────────────────────────────────────────────────────────────
    104 
    105 #[derive(Deserialize)]
    106 pub struct SessionIdBody {
    107     pub session_id: String,
    108 }
    109 
    110 #[derive(Serialize)]
    111 pub struct ResolvedPrediction {
    112     entry_price: f64,
    113     close_price: f64,
    114     direction: Direction,
    115     fee_pct: f64,
    116     profit: f64,
    117 }
    118 
    119 #[derive(Serialize)]
    120 pub struct NextCandleResponse {
    121     candle: crate::data::Candle,
    122     resolved: Option<ResolvedPrediction>,
    123 }
    124 
    125 pub async fn next_candle(
    126     State(state): State<AppState>,
    127     Json(body): Json<SessionIdBody>,
    128 ) -> impl IntoResponse {
    129     let Some(mut s) = state.sessions.get_mut(&body.session_id) else {
    130         return (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "session_not_found"}))).into_response();
    131     };
    132 
    133     let next = s.current_index + 1;
    134     if next >= state.data.candles.len() {
    135         return (StatusCode::CONFLICT, Json(serde_json::json!({"error": "end_of_data"}))).into_response();
    136     }
    137 
    138     s.current_index = next;
    139     let candle = state.data.candles[next].clone();
    140 
    141     let resolved = if s.prediction.as_ref().map_or(false, |p| next >= p.target_index) {
    142         let pred = s.prediction.take().unwrap();
    143         let profit = session::resolve_prediction(&pred, candle.close);
    144         Some(ResolvedPrediction {
    145             entry_price: pred.entry_price,
    146             close_price: candle.close,
    147             direction: pred.direction,
    148             fee_pct: pred.fee_pct,
    149             profit,
    150         })
    151     } else {
    152         None
    153     };
    154     drop(s);
    155 
    156     Json(NextCandleResponse { candle, resolved }).into_response()
    157 }
    158 
    159 // ── /api/predict ──────────────────────────────────────────────────────────────
    160 
    161 #[derive(Deserialize)]
    162 pub struct PredictBody {
    163     session_id: String,
    164     bars_ahead: u32,
    165     direction: Direction,
    166     fee_pct: f64,
    167 }
    168 
    169 #[derive(Serialize)]
    170 pub struct PredictResponse {
    171     entry_price: f64,
    172 }
    173 
    174 pub async fn predict(
    175     State(state): State<AppState>,
    176     Json(body): Json<PredictBody>,
    177 ) -> impl IntoResponse {
    178     if body.bars_ahead == 0 {
    179         return (StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "bars_ahead_must_be_positive"}))).into_response();
    180     }
    181 
    182     let Some(mut s) = state.sessions.get_mut(&body.session_id) else {
    183         return (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "session_not_found"}))).into_response();
    184     };
    185 
    186     let entry_price = state.data.candles[s.current_index].close;
    187     s.prediction = Some(ActivePrediction {
    188         target_index: s.current_index + body.bars_ahead as usize,
    189         entry_price,
    190         direction: body.direction,
    191         fee_pct: body.fee_pct.clamp(0.0, 100.0),
    192     });
    193     drop(s);
    194 
    195     Json(PredictResponse { entry_price }).into_response()
    196 }
    197 
    198 // ── /api/predict/cancel ───────────────────────────────────────────────────────
    199 
    200 pub async fn cancel_predict(
    201     State(state): State<AppState>,
    202     Json(body): Json<SessionIdBody>,
    203 ) -> impl IntoResponse {
    204     let Some(mut s) = state.sessions.get_mut(&body.session_id) else {
    205         return (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "session_not_found"}))).into_response();
    206     };
    207     s.prediction = None;
    208     drop(s);
    209     Json(serde_json::json!({})).into_response()
    210 }
    211 
    212 // ── / ─────────────────────────────────────────────────────────────────────────
    213 
    214 pub async fn serve_html() -> impl IntoResponse {
    215     Html(include_str!("../frontend/index.html"))
    216 }