from __future__ import annotations import asyncio import os import sqlite3 from dataclasses import dataclass from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union from aiohttp import web import numpy as np import pandas as pd from statsmodels.tsa.stattools import adfuller, coint # type: ignore from statsmodels.tsa.vector_ar.vecm import coint_johansen # type: ignore from cvttpy_tools.app import App from cvttpy_tools.base import NamedObject from cvttpy_tools.config import Config, CvttAppConfig from cvttpy_tools.logger import Log from cvttpy_tools.timeutils import NanoPerSec, SecPerHour, current_nanoseconds from cvttpy_tools.web.rest_client import RESTSender from cvttpy_tools.web.rest_service import RestService from cvttpy_trading.trading.exchange_config import ExchangeAccounts from cvttpy_trading.trading.instrument import ExchangeInstrument from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate, MdSummary from pairs_trading.apps.pair_selector.renderer import HtmlRenderer @dataclass class BacktestAggregate: aggr_time_ns_: int num_trades_: Optional[int] @dataclass class InstrumentQuality(NamedObject): instrument_: ExchangeInstrument record_count_: int latest_tstamp_: Optional[pd.Timestamp] status_: str reason_: str @dataclass class PairStats(NamedObject): pair_name_: str instrument_a_: ExchangeInstrument instrument_b_: ExchangeInstrument pvalue_eg_: Optional[float] pvalue_adf_: Optional[float] pvalue_j_: Optional[float] trace_stat_j_: Optional[float] rank_eg_: int = 0 rank_adf_: int = 0 rank_j_: int = 0 composite_rank_: int = 0 def as_dict(self) -> Dict[str, Any]: return { "exchange_a": self.instrument_a_.exchange_id_, "exchange_b": self.instrument_b_.exchange_id_, "pair_name": self.pair_name_, "instrument_a": self.instrument_a_.instrument_id(), "instrument_b": self.instrument_b_.instrument_id(), "pvalue_eg": self.pvalue_eg_, "pvalue_adf": self.pvalue_adf_, "pvalue_j": self.pvalue_j_, "trace_stat_j": self.trace_stat_j_, "rank_eg": self.rank_eg_, "rank_adf": self.rank_adf_, "rank_j": self.rank_j_, "composite_rank": self.composite_rank_, } def _extract_price_from_fields( price_field: str, inst: ExchangeInstrument, open: Optional[float], high: Optional[float], low: Optional[float], close: Optional[float], vwap: Optional[float], ) -> float: field_map = { "open": open, "high": high, "low": low, "close": close, "vwap": vwap, } raw = field_map.get(price_field, close) if raw is None: raw = 0.0 return inst.get_price(raw) class DataFetcher(NamedObject): sender_: RESTSender interval_sec_: int history_depth_sec_: int def __init__( self, base_url: str, interval_sec: int, history_depth_sec: int, ) -> None: self.sender_ = RESTSender(base_url=base_url) self.interval_sec_ = interval_sec self.history_depth_sec_ = history_depth_sec def fetch( self, exch_acct: str, inst: ExchangeInstrument ) -> List[MdTradesAggregate]: rqst_data = { "exch_acct": exch_acct, "instrument_id": inst.instrument_id(), "interval_sec": self.interval_sec_, "history_depth_sec": self.history_depth_sec_, } response = self.sender_.send_post(endpoint="md_summary", post_body=rqst_data) if response.status_code not in (200, 201): Log.error( f"{self.fname()}: error {response.status_code} for {inst.details_short()}: {response.text}" ) return [] mdsums: List[MdSummary] = MdSummary.from_REST_response(response=response) return [ mdsum.create_md_trades_aggregate( exch_acct=exch_acct, exch_inst=inst, interval_sec=self.interval_sec_ ) for mdsum in mdsums ] AggregateLike = Union[MdTradesAggregate, BacktestAggregate] class QualityChecker(NamedObject): interval_sec_: int def __init__(self, interval_sec: int) -> None: self.interval_sec_ = interval_sec def evaluate( self, inst: ExchangeInstrument, aggr: Sequence[AggregateLike], now_ts: Optional[pd.Timestamp] = None, ) -> InstrumentQuality: if len(aggr) == 0: return InstrumentQuality( instrument_=inst, record_count_=0, latest_tstamp_=None, status_="FAIL", reason_="no records", ) aggr_sorted = sorted(aggr, key=lambda a: a.aggr_time_ns_) latest_ts = pd.to_datetime(aggr_sorted[-1].aggr_time_ns_, unit="ns", utc=True) now_ts = now_ts or pd.Timestamp.utcnow() recency_cutoff = now_ts - pd.Timedelta(seconds=2 * self.interval_sec_) if latest_ts <= recency_cutoff: return InstrumentQuality( instrument_=inst, record_count_=len(aggr_sorted), latest_tstamp_=latest_ts, status_="FAIL", reason_=f"stale: latest {latest_ts} <= cutoff {recency_cutoff}", ) gaps_ok, reason = self._check_gaps(aggr_sorted) status = "PASS" if gaps_ok else "FAIL" return InstrumentQuality( instrument_=inst, record_count_=len(aggr_sorted), latest_tstamp_=latest_ts, status_=status, reason_=reason, ) def _check_gaps(self, aggr: Sequence[AggregateLike]) -> Tuple[bool, str]: NUM_TRADES_THRESHOLD = 50 if len(aggr) < 2: return True, "ok" interval_ns = self.interval_sec_ * NanoPerSec for idx in range(1, len(aggr)): prev = aggr[idx - 1] curr = aggr[idx] delta = curr.aggr_time_ns_ - prev.aggr_time_ns_ missing_intervals = int(delta // interval_ns) - 1 if missing_intervals <= 0: continue prev_nt = prev.num_trades_ next_nt = curr.num_trades_ estimate = self._approximate_num_trades(prev_nt, next_nt) if estimate > NUM_TRADES_THRESHOLD: return False, ( f"gap of {missing_intervals} interval(s), est num_trades={estimate} > {NUM_TRADES_THRESHOLD}" ) return True, "ok" @staticmethod def _approximate_num_trades(prev_nt: Optional[int], next_nt: Optional[int]) -> float: if prev_nt is None and next_nt is None: return 0.0 if prev_nt is None: return float(next_nt or 0) if next_nt is None: return float(prev_nt) return (prev_nt + next_nt) / 2.0 class PairAnalyzer(NamedObject): price_field_: str interval_sec_: int def __init__(self, price_field: str, interval_sec: int) -> None: self.price_field_ = price_field self.interval_sec_ = interval_sec def analyze( self, series: Dict[ExchangeInstrument, pd.DataFrame] ) -> Dict[str, PairStats]: instruments = list(series.keys()) results: Dict[str, PairStats] = {} for i in range(len(instruments)): for j in range(i + 1, len(instruments)): inst_a, inst_b, pair_name = self._normalized_pair( instruments[i], instruments[j] ) df_a = series[inst_a][["tstamp", "price"]].rename( columns={"price": "price_a"} ) df_b = series[inst_b][["tstamp", "price"]].rename( columns={"price": "price_b"} ) merged = pd.merge(df_a, df_b, on="tstamp", how="inner").sort_values( "tstamp" ) # Log.info(f"{self.fname()}: analyzing {pair_name}") stats = self._compute_stats(inst_a, inst_b, pair_name, merged) if stats: results[pair_name] = stats return self._rank(results) def _compute_stats( self, inst_a: ExchangeInstrument, inst_b: ExchangeInstrument, pair_name: str, merged: pd.DataFrame, ) -> Optional[PairStats]: if len(merged) < 2: return None px_a = merged["price_a"].astype(float) px_b = merged["price_b"].astype(float) std_a = float(px_a.std()) std_b = float(px_b.std()) if std_a == 0 or std_b == 0: return None z_a = (px_a - float(px_a.mean())) / std_a z_b = (px_b - float(px_b.mean())) / std_b p_eg: Optional[float] p_adf: Optional[float] p_j: Optional[float] trace_stat: Optional[float] try: p_eg = float(coint(z_a, z_b)[1]) except Exception as exc: Log.warning( f"{self.fname()}: EG failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}" ) p_eg = None try: spread = z_a - z_b p_adf = float(adfuller(spread, maxlag=1, regression="c")[1]) except Exception as exc: Log.warning( f"{self.fname()}: ADF failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}" ) p_adf = None try: data = np.column_stack([z_a, z_b]) res = coint_johansen(data, det_order=0, k_ar_diff=1) trace_stat = float(res.lr1[0]) cv10, cv5, cv1 = res.cvt[0] if trace_stat > cv1: p_j = 0.01 elif trace_stat > cv5: p_j = 0.05 elif trace_stat > cv10: p_j = 0.10 else: p_j = 1.0 except Exception as exc: Log.warning( f"{self.fname()}: Johansen failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}" ) p_j = None trace_stat = None return PairStats( pair_name_=pair_name, instrument_a_=inst_a, instrument_b_=inst_b, pvalue_eg_=p_eg, pvalue_adf_=p_adf, pvalue_j_=p_j, trace_stat_j_=trace_stat, ) def _rank(self, results: Dict[str, PairStats]) -> Dict[str, PairStats]: ranked = list(results.values()) self._assign_ranks(ranked, key=lambda r: r.pvalue_eg_, attr="rank_eg_") self._assign_ranks(ranked, key=lambda r: r.pvalue_adf_, attr="rank_adf_") self._assign_ranks(ranked, key=lambda r: r.pvalue_j_, attr="rank_j_") for res in ranked: res.composite_rank_ = res.rank_eg_ + res.rank_adf_ # + res.rank_j_ ranked.sort(key=lambda r: r.composite_rank_) return {res.pair_name_: res for res in ranked} @staticmethod def _normalized_pair( inst_a: ExchangeInstrument, inst_b: ExchangeInstrument ) -> Tuple[ExchangeInstrument, ExchangeInstrument, str]: inst_a_id = PairAnalyzer._pair_label(inst_a.instrument_id()) inst_b_id = PairAnalyzer._pair_label(inst_b.instrument_id()) if inst_a_id <= inst_b_id: return inst_a, inst_b, f"{inst_a_id}<->{inst_b_id}" return inst_b, inst_a, f"{inst_b_id}<->{inst_a_id}" @staticmethod def _pair_label(instrument_id: str) -> str: if instrument_id.startswith("PAIR-"): return instrument_id[len("PAIR-") :] return instrument_id @staticmethod def _assign_ranks(results: List[PairStats], key, attr: str) -> None: values = [key(r) for r in results] sorted_vals = sorted([v for v in values if v is not None]) for res in results: val = key(res) if val is None: setattr(res, attr, len(sorted_vals) + 1) continue rank = 1 + sum(1 for v in sorted_vals if v < val) setattr(res, attr, rank) class PairSelectionEngine(NamedObject): config_: object instruments_: List[ExchangeInstrument] price_field_: str fetcher_: DataFetcher quality_: QualityChecker analyzer_: PairAnalyzer interval_sec_: int history_depth_sec_: int data_quality_cache_: List[InstrumentQuality] pair_results_cache_: Dict[str, PairStats] def __init__( self, config: Config, instruments: List[ExchangeInstrument], price_field: str, ) -> None: self.config_ = config self.instruments_ = instruments self.price_field_ = price_field interval_sec = int(config.get_value("interval_sec", 0)) history_depth_sec = int(config.get_value("history_depth_hours", 0)) * SecPerHour base_url = config.get_value("cvtt_base_url", None) assert interval_sec > 0, "interval_sec must be > 0" assert history_depth_sec > 0, "history_depth_sec must be > 0" assert base_url, "cvtt_base_url must be set" self.fetcher_ = DataFetcher( base_url=base_url, interval_sec=interval_sec, history_depth_sec=history_depth_sec, ) self.quality_ = QualityChecker(interval_sec=interval_sec) self.analyzer_ = PairAnalyzer( price_field=price_field, interval_sec=interval_sec ) self.interval_sec_ = interval_sec self.history_depth_sec_ = history_depth_sec self.data_quality_cache_ = [] self.pair_results_cache_ = {} async def run_once(self) -> None: quality_results: List[InstrumentQuality] = [] price_series: Dict[ExchangeInstrument, pd.DataFrame] = {} for inst in self.instruments_: exch_acct = inst.user_data_.get("exch_acct") or inst.exchange_id_ aggr = self.fetcher_.fetch(exch_acct=exch_acct, inst=inst) q = self.quality_.evaluate(inst, aggr) quality_results.append(q) if q.status_ != "PASS": continue df = self._to_dataframe(aggr, inst) if len(df) > 0: price_series[inst] = df self.data_quality_cache_ = quality_results self.pair_results_cache_ = self.analyzer_.analyze(price_series) def _to_dataframe( self, aggr: List[MdTradesAggregate], inst: ExchangeInstrument ) -> pd.DataFrame: rows: List[Dict[str, Any]] = [] for item in aggr: rows.append( { "tstamp": pd.to_datetime(item.aggr_time_ns_, unit="ns", utc=True), "price": self._extract_price(item, inst), "num_trades": item.num_trades_, } ) df = pd.DataFrame(rows) return df.sort_values("tstamp").reset_index(drop=True) def _extract_price( self, aggr: MdTradesAggregate, inst: ExchangeInstrument ) -> float: return _extract_price_from_fields( price_field=self.price_field_, inst=inst, open=aggr.open_, high=aggr.high_, low=aggr.low_, close=aggr.close_, vwap=aggr.vwap_, ) def sleep_seconds_until_next_cycle(self) -> float: now_ns = current_nanoseconds() interval_ns = self.interval_sec_ * NanoPerSec next_boundary = (now_ns // interval_ns + 1) * interval_ns return max(0.0, (next_boundary - now_ns) / NanoPerSec) def quality_dicts(self) -> List[Dict[str, Any]]: res: List[Dict[str, Any]] = [] for q in self.data_quality_cache_: res.append( { "instrument": q.instrument_.instrument_id(), "record_count": q.record_count_, "latest_tstamp": ( q.latest_tstamp_.isoformat() if q.latest_tstamp_ else None ), "status": q.status_, "reason": q.reason_, } ) return res def pair_dicts(self) -> Dict[str, Dict[str, Any]]: return { pair_name: stats.as_dict() for pair_name, stats in self.pair_results_cache_.items() } class PairSelectionBacktest(NamedObject): config_: object instruments_: List[ExchangeInstrument] price_field_: str input_db_: str output_db_: str interval_sec_: int history_depth_hours_: int quality_: QualityChecker analyzer_: PairAnalyzer inst_by_key_: Dict[Tuple[str, str], ExchangeInstrument] inst_by_id_: Dict[str, Optional[ExchangeInstrument]] ambiguous_ids_: Set[str] def __init__( self, config: Config, instruments: List[ExchangeInstrument], price_field: str, input_db: str, output_db: str, ) -> None: self.config_ = config self.instruments_ = instruments self.price_field_ = price_field self.input_db_ = input_db self.output_db_ = output_db interval_sec = int(config.get_value("interval_sec", 0)) if interval_sec <= 0: Log.warning( f"{self.fname()}: interval_sec not set; defaulting to 60 seconds" ) interval_sec = 60 history_depth_hours = int(config.get_value("history_depth_hours", 0)) assert history_depth_hours > 0, "history_depth_hours must be > 0" self.interval_sec_ = interval_sec self.history_depth_hours_ = history_depth_hours self.quality_ = QualityChecker(interval_sec=interval_sec) self.analyzer_ = PairAnalyzer( price_field=price_field, interval_sec=interval_sec ) self.inst_by_key_ = { (inst.exchange_id_, inst.instrument_id()): inst for inst in instruments } self.inst_by_id_ = {} self.ambiguous_ids_ = set() for inst in instruments: inst_id = inst.instrument_id() if inst_id in self.inst_by_id_: existing = self.inst_by_id_[inst_id] if existing is not None and existing.exchange_id_ != inst.exchange_id_: self.inst_by_id_[inst_id] = None self.ambiguous_ids_.add(inst_id) elif inst_id not in self.ambiguous_ids_: self.inst_by_id_[inst_id] = inst if self.ambiguous_ids_: Log.warning( f"{self.fname()}: ambiguous instrument_id(s) without exchange_id: " f"{sorted(self.ambiguous_ids_)}" ) def run(self) -> None: df = self._load_input_df() if df.empty: Log.warning(f"{self.fname()}: no rows in md_1min_bars") return df = self._filter_instruments(df) if df.empty: Log.warning(f"{self.fname()}: no rows after instrument filtering") return conn = self._init_output_db() try: self._run_backtest(df, conn) finally: conn.commit() conn.close() def _load_input_df(self) -> pd.DataFrame: if not os.path.exists(self.input_db_): raise FileNotFoundError(f"input_db not found: {self.input_db_}") with sqlite3.connect(self.input_db_) as conn: df = pd.read_sql_query( """ SELECT tstamp, tstamp_ns, exchange_id, instrument_id, open, high, low, close, volume, vwap, num_trades FROM md_1min_bars """, conn, ) if df.empty: return df ts_ns = pd.to_datetime(df["tstamp_ns"], unit="ns", utc=True, errors="coerce") ts_txt = pd.to_datetime(df["tstamp"], utc=True, errors="coerce") df["tstamp"] = ts_ns.fillna(ts_txt) df = df.dropna(subset=["tstamp", "instrument_id"]).copy() df["exchange_id"] = df["exchange_id"].fillna("") df["instrument_id"] = df["instrument_id"].astype(str) df["tstamp_ns"] = df["tstamp"].astype("int64") return df.sort_values("tstamp").reset_index(drop=True) def _filter_instruments(self, df: pd.DataFrame) -> pd.DataFrame: instrument_ids = {inst.instrument_id() for inst in self.instruments_} df = df[df["instrument_id"].isin(instrument_ids)].copy() if "exchange_id" in df.columns: exchange_ids = {inst.exchange_id_ for inst in self.instruments_} df = df[ (df["exchange_id"].isin(exchange_ids)) | (df["exchange_id"] == "") ].copy() return df def _init_output_db(self) -> sqlite3.Connection: if os.path.exists(self.output_db_): os.remove(self.output_db_) conn = sqlite3.connect(self.output_db_) conn.execute( """ CREATE TABLE pair_selection_history ( tstamp TEXT, tstamp_ns INTEGER, pair_name TEXT, exchange_a TEXT, instrument_a TEXT, exchange_b TEXT, instrument_b TEXT, pvalue_eg REAL, pvalue_adf REAL, pvalue_j REAL, trace_stat_j REAL, rank_eg INTEGER, rank_adf INTEGER, rank_j INTEGER, composite_rank REAL ) """ ) conn.execute( """ CREATE INDEX idx_pair_selection_history_pair_name ON pair_selection_history (pair_name) """ ) conn.execute( """ CREATE UNIQUE INDEX idx_pair_selection_history_tstamp_pair ON pair_selection_history (tstamp, pair_name) """ ) conn.commit() return conn def _resolve_instrument( self, exchange_id: str, instrument_id: str ) -> Optional[ExchangeInstrument]: if exchange_id: inst = self.inst_by_key_.get((exchange_id, instrument_id)) if inst is not None: return inst inst = self.inst_by_id_.get(instrument_id) if inst is None and instrument_id in self.ambiguous_ids_: return None return inst def _build_day_series( self, df_day: pd.DataFrame ) -> Dict[ExchangeInstrument, pd.DataFrame]: series: Dict[ExchangeInstrument, pd.DataFrame] = {} group_cols = ["exchange_id", "instrument_id"] for key, group in df_day.groupby(group_cols, dropna=False): exchange_id, instrument_id = key inst = self._resolve_instrument(str(exchange_id or ""), str(instrument_id)) if inst is None: continue df_inst = group.copy() df_inst["price"] = [ _extract_price_from_fields( price_field=self.price_field_, inst=inst, open=float(row.open), #type: ignore high=float(row.high), #type: ignore low=float(row.low), #type: ignore close=float(row.close), #type: ignore vwap=float(row.vwap),#type: ignore ) for row in df_inst.itertuples(index=False) ] df_inst = df_inst[["tstamp", "tstamp_ns", "price", "num_trades"]] if inst in series: series[inst] = pd.concat([series[inst], df_inst], ignore_index=True) else: series[inst] = df_inst for inst in list(series.keys()): series[inst] = series[inst].sort_values("tstamp").reset_index(drop=True) return series def _run_backtest(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None: window_minutes = self.history_depth_hours_ * 60 window_td = pd.Timedelta(minutes=window_minutes) step_td = pd.Timedelta(seconds=self.interval_sec_) df = df.copy() df["day"] = df["tstamp"].dt.normalize() days = sorted(df["day"].unique()) for day in days: day_label = pd.Timestamp(day).date() df_day = df[df["day"] == day] t0 = df_day["tstamp"].min() t_last = df_day["tstamp"].max() if t_last - t0 < window_td: Log.warning( f"{self.fname()}: skipping {day_label} (insufficient data)" ) continue day_series = self._build_day_series(df_day) if len(day_series) < 2: Log.warning( f"{self.fname()}: skipping {day_label} (insufficient instruments)" ) continue start = t0 expected_end = start + window_td while expected_end <= t_last: window_slices: Dict[ExchangeInstrument, pd.DataFrame] = {} ts: Optional[pd.Timestamp] = None for inst, df_inst in day_series.items(): df_win = df_inst[ (df_inst["tstamp"] >= start) & (df_inst["tstamp"] < expected_end) ] if df_win.empty: continue window_slices[inst] = df_win last_ts = df_win["tstamp"].iloc[-1] if ts is None or last_ts > ts: ts = last_ts if window_slices and ts is not None: price_series: Dict[ExchangeInstrument, pd.DataFrame] = {} for inst, df_win in window_slices.items(): aggr = self._to_backtest_aggregates(df_win) q = self.quality_.evaluate( inst=inst, aggr=aggr, now_ts=ts ) if q.status_ != "PASS": continue price_series[inst] = df_win[["tstamp", "price"]] pair_results = self.analyzer_.analyze(price_series) Log.info(f"{self.fname()}: Saving Results for window ending {ts}") self._insert_results(conn, ts, pair_results) start = start + step_td expected_end = start + window_td @staticmethod def _to_backtest_aggregates(df_win: pd.DataFrame) -> List[BacktestAggregate]: aggr: List[BacktestAggregate] = [] for tstamp_ns, num_trades in zip(df_win["tstamp_ns"], df_win["num_trades"]): nt = None if pd.isna(num_trades) else int(num_trades) aggr.append( BacktestAggregate(aggr_time_ns_=int(tstamp_ns), num_trades_=nt) ) return aggr @staticmethod def _insert_results( conn: sqlite3.Connection, ts: pd.Timestamp, pair_results: Dict[str, PairStats], ) -> None: if not pair_results: return iso = ts.isoformat() ns = int(ts.value) rows = [] for pair_name in sorted(pair_results.keys()): stats = pair_results[pair_name] rows.append( ( iso, ns, pair_name, stats.instrument_a_.exchange_id_, stats.instrument_a_.instrument_id(), stats.instrument_b_.exchange_id_, stats.instrument_b_.instrument_id(), stats.pvalue_eg_, stats.pvalue_adf_, stats.pvalue_j_, stats.trace_stat_j_, stats.rank_eg_, stats.rank_adf_, stats.rank_j_, stats.composite_rank_, ) ) conn.executemany( """ INSERT INTO pair_selection_history ( tstamp, tstamp_ns, pair_name, exchange_a, instrument_a, exchange_b, instrument_b, pvalue_eg, pvalue_adf, pvalue_j, trace_stat_j, rank_eg, rank_adf, rank_j, composite_rank ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, rows, ) conn.commit() class PairSelector(NamedObject): instruments_: List[ExchangeInstrument] engine_: PairSelectionEngine rest_service_: Optional[RestService] backtest_: Optional[PairSelectionBacktest] def __init__(self) -> None: App.instance().add_cmdline_arg("--oneshot", action="store_true", default=False) App.instance().add_cmdline_arg("--backtest", action="store_true", default=False) App.instance().add_cmdline_arg("--input_db", default=None) App.instance().add_cmdline_arg("--output_db", default=None) App.instance().add_call(App.Stage.Config, self._on_config()) App.instance().add_call(App.Stage.Run, self.run()) async def _on_config(self) -> None: cfg = CvttAppConfig.instance() self.instruments_ = self._load_instruments(cfg) price_field = cfg.get_value("model/stat_model_price", "close") self.backtest_ = None self.rest_service_ = None if App.instance().get_argument("backtest", False): input_db = App.instance().get_argument("input_db", None) output_db = App.instance().get_argument("output_db", None) if not input_db or not output_db: raise ValueError( "--input_db and --output_db are required when --backtest is set" ) self.backtest_ = PairSelectionBacktest( config=cfg, instruments=self.instruments_, price_field=price_field, input_db=input_db, output_db=output_db, ) return self.engine_ = PairSelectionEngine( config=cfg, instruments=self.instruments_, price_field=price_field, ) self.rest_service_ = RestService(config_key="/api/REST") self.rest_service_.add_handler("GET", "/data_quality", self._on_data_quality) self.rest_service_.add_handler( "GET", "/pair_selection", self._on_pair_selection ) def _load_instruments(self, cfg: CvttAppConfig) -> List[ExchangeInstrument]: instruments_cfg = cfg.get_value("instruments", []) instruments: List[ExchangeInstrument] = [] assert len(instruments_cfg) >= 2, "at least two instruments required" for item in instruments_cfg: if isinstance(item, str): parts = item.split(":", 1) if len(parts) != 2: raise ValueError(f"invalid instrument format: {item}") exch_acct, instrument_id = parts elif isinstance(item, dict): exch_acct = item.get("exch_acct", "") instrument_id = item.get("instrument_id", "") if not exch_acct or not instrument_id: raise ValueError(f"invalid instrument config: {item}") else: raise ValueError(f"unsupported instrument entry: {item}") exch_inst = ExchangeAccounts.instance().get_exchange_instrument( exch_acct=exch_acct, instrument_id=instrument_id ) assert ( exch_inst is not None ), f"no ExchangeInstrument for {exch_acct}:{instrument_id}" exch_inst.user_data_["exch_acct"] = exch_acct instruments.append(exch_inst) return instruments async def run(self) -> None: if App.instance().get_argument("backtest", False): if self.backtest_ is None: raise RuntimeError("backtest runner not initialized") self.backtest_.run() return oneshot = App.instance().get_argument("oneshot", False) while True: await self.engine_.run_once() if oneshot: break sleep_for = self.engine_.sleep_seconds_until_next_cycle() await asyncio.sleep(sleep_for) async def _on_data_quality(self, request: web.Request) -> web.Response: fmt = request.query.get("format", "html").lower() quality = self.engine_.quality_dicts() if fmt == "json": return web.json_response(quality) return web.Response( text=HtmlRenderer.render_data_quality(quality), content_type="text/html" ) async def _on_pair_selection(self, request: web.Request) -> web.Response: fmt = request.query.get("format", "html").lower() pairs = self.engine_.pair_dicts() if fmt == "json": return web.json_response(pairs) return web.Response( text=HtmlRenderer.render_pairs(pairs), content_type="text/html" ) if __name__ == "__main__": App() CvttAppConfig() PairSelector() App.instance().run()