From 98f6defe96b6bbfc4e2dad14c644c7901dc76dbf Mon Sep 17 00:00:00 2001 From: Oleg Sheynin Date: Thu, 5 Feb 2026 04:05:53 +0000 Subject: [PATCH] 0.0.8 --- VERSION | 2 +- apps/pair_selector/pair_selector.py | 442 +++++++++++++++++- apps/pair_selector/pair_selector_engine.py.md | 394 ---------------- research/notebooks/pair_select_hist.ipynb | 329 +++++++++++++ 4 files changed, 753 insertions(+), 414 deletions(-) delete mode 100644 apps/pair_selector/pair_selector_engine.py.md create mode 100644 research/notebooks/pair_select_hist.ipynb diff --git a/VERSION b/VERSION index 5c4511c..7d6b3eb 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.7 \ No newline at end of file +0.0.8 \ No newline at end of file diff --git a/apps/pair_selector/pair_selector.py b/apps/pair_selector/pair_selector.py index 5220e31..1b14896 100644 --- a/apps/pair_selector/pair_selector.py +++ b/apps/pair_selector/pair_selector.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import os +import sqlite3 from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union from aiohttp import web import numpy as np @@ -26,6 +28,12 @@ from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate, MdSumm from pairs_trading.apps.pair_selector.renderer import HtmlRenderer +@dataclass +class BacktestAggregate: + aggr_time_ns_: int + num_trades_: Optional[int] + + @dataclass class InstrumentQuality(NamedObject): instrument_: ExchangeInstrument @@ -51,6 +59,9 @@ class PairStats(NamedObject): def as_dict(self) -> Dict[str, Any]: return { + "exchange_a": self.instrument_a_.exchange_id_, + "exchange_b": self.instrument_b_.exchange_id_, + "pair_name": self.pair_name_, "instrument_a": self.instrument_a_.instrument_id(), "instrument_b": self.instrument_b_.instrument_id(), "pvalue_eg": self.pvalue_eg_, @@ -64,6 +75,28 @@ class PairStats(NamedObject): } +def _extract_price_from_fields( + price_field: str, + inst: ExchangeInstrument, + open: Optional[float], + high: Optional[float], + low: Optional[float], + close: Optional[float], + vwap: Optional[float], +) -> float: + field_map = { + "open": open, + "high": high, + "low": low, + "close": close, + "vwap": vwap, + } + raw = field_map.get(price_field, close) + if raw is None: + raw = 0.0 + return inst.get_price(raw) + + class DataFetcher(NamedObject): sender_: RESTSender interval_sec_: int @@ -103,6 +136,9 @@ class DataFetcher(NamedObject): ] +AggregateLike = Union[MdTradesAggregate, BacktestAggregate] + + class QualityChecker(NamedObject): interval_sec_: int @@ -110,7 +146,10 @@ class QualityChecker(NamedObject): self.interval_sec_ = interval_sec def evaluate( - self, inst: ExchangeInstrument, aggr: List[MdTradesAggregate] + self, + inst: ExchangeInstrument, + aggr: Sequence[AggregateLike], + now_ts: Optional[pd.Timestamp] = None, ) -> InstrumentQuality: if len(aggr) == 0: return InstrumentQuality( @@ -124,7 +163,7 @@ class QualityChecker(NamedObject): aggr_sorted = sorted(aggr, key=lambda a: a.aggr_time_ns_) latest_ts = pd.to_datetime(aggr_sorted[-1].aggr_time_ns_, unit="ns", utc=True) - now_ts = pd.Timestamp.utcnow() + now_ts = now_ts or pd.Timestamp.utcnow() recency_cutoff = now_ts - pd.Timedelta(seconds=2 * self.interval_sec_) if latest_ts <= recency_cutoff: return InstrumentQuality( @@ -145,7 +184,7 @@ class QualityChecker(NamedObject): reason_=reason, ) - def _check_gaps(self, aggr: List[MdTradesAggregate]) -> Tuple[bool, str]: + def _check_gaps(self, aggr: Sequence[AggregateLike]) -> Tuple[bool, str]: NUM_TRADES_THRESHOLD = 50 if len(aggr) < 2: return True, "ok" @@ -169,11 +208,11 @@ class QualityChecker(NamedObject): return True, "ok" @staticmethod - def _approximate_num_trades(prev_nt: int, next_nt: int) -> float: + def _approximate_num_trades(prev_nt: Optional[int], next_nt: Optional[int]) -> float: if prev_nt is None and next_nt is None: return 0.0 if prev_nt is None: - return float(next_nt) + return float(next_nt or 0) if next_nt is None: return float(prev_nt) return (prev_nt + next_nt) / 2.0 @@ -206,6 +245,7 @@ class PairAnalyzer(NamedObject): merged = pd.merge(df_a, df_b, on="tstamp", how="inner").sort_values( "tstamp" ) + # Log.info(f"{self.fname()}: analyzing {pair_name}") stats = self._compute_stats(inst_a, inst_b, pair_name, merged) if stats: results[pair_name] = stats @@ -289,7 +329,7 @@ class PairAnalyzer(NamedObject): self._assign_ranks(ranked, key=lambda r: r.pvalue_adf_, attr="rank_adf_") self._assign_ranks(ranked, key=lambda r: r.pvalue_j_, attr="rank_j_") for res in ranked: - res.composite_rank_ = res.rank_eg_ + res.rank_adf_ + res.rank_j_ + res.composite_rank_ = res.rank_eg_ + res.rank_adf_ # + res.rank_j_ ranked.sort(key=lambda r: r.composite_rank_) return {res.pair_name_: res for res in ranked} @@ -402,17 +442,15 @@ class PairSelectionEngine(NamedObject): def _extract_price( self, aggr: MdTradesAggregate, inst: ExchangeInstrument ) -> float: - price_field = self.price_field_ - # MdTradesAggregate inherits hist bar with fields open_, high_, low_, close_, vwap_ - field_map = { - "open": aggr.open_, - "high": aggr.high_, - "low": aggr.low_, - "close": aggr.close_, - "vwap": aggr.vwap_, - } - raw = field_map.get(price_field, aggr.close_) - return inst.get_price(raw) + return _extract_price_from_fields( + price_field=self.price_field_, + inst=inst, + open=aggr.open_, + high=aggr.high_, + low=aggr.low_, + close=aggr.close_, + vwap=aggr.vwap_, + ) def sleep_seconds_until_next_cycle(self) -> float: now_ns = current_nanoseconds() @@ -443,13 +481,356 @@ class PairSelectionEngine(NamedObject): } +class PairSelectionBacktest(NamedObject): + config_: object + instruments_: List[ExchangeInstrument] + price_field_: str + input_db_: str + output_db_: str + interval_sec_: int + history_depth_hours_: int + quality_: QualityChecker + analyzer_: PairAnalyzer + inst_by_key_: Dict[Tuple[str, str], ExchangeInstrument] + inst_by_id_: Dict[str, Optional[ExchangeInstrument]] + ambiguous_ids_: Set[str] + + def __init__( + self, + config: Config, + instruments: List[ExchangeInstrument], + price_field: str, + input_db: str, + output_db: str, + ) -> None: + self.config_ = config + self.instruments_ = instruments + self.price_field_ = price_field + self.input_db_ = input_db + self.output_db_ = output_db + + interval_sec = int(config.get_value("interval_sec", 0)) + if interval_sec <= 0: + Log.warning( + f"{self.fname()}: interval_sec not set; defaulting to 60 seconds" + ) + interval_sec = 60 + history_depth_hours = int(config.get_value("history_depth_hours", 0)) + assert history_depth_hours > 0, "history_depth_hours must be > 0" + + self.interval_sec_ = interval_sec + self.history_depth_hours_ = history_depth_hours + self.quality_ = QualityChecker(interval_sec=interval_sec) + self.analyzer_ = PairAnalyzer( + price_field=price_field, interval_sec=interval_sec + ) + + self.inst_by_key_ = { + (inst.exchange_id_, inst.instrument_id()): inst for inst in instruments + } + self.inst_by_id_ = {} + self.ambiguous_ids_ = set() + for inst in instruments: + inst_id = inst.instrument_id() + if inst_id in self.inst_by_id_: + existing = self.inst_by_id_[inst_id] + if existing is not None and existing.exchange_id_ != inst.exchange_id_: + self.inst_by_id_[inst_id] = None + self.ambiguous_ids_.add(inst_id) + elif inst_id not in self.ambiguous_ids_: + self.inst_by_id_[inst_id] = inst + + if self.ambiguous_ids_: + Log.warning( + f"{self.fname()}: ambiguous instrument_id(s) without exchange_id: " + f"{sorted(self.ambiguous_ids_)}" + ) + + def run(self) -> None: + df = self._load_input_df() + if df.empty: + Log.warning(f"{self.fname()}: no rows in md_1min_bars") + return + + df = self._filter_instruments(df) + if df.empty: + Log.warning(f"{self.fname()}: no rows after instrument filtering") + return + + conn = self._init_output_db() + try: + self._run_backtest(df, conn) + finally: + conn.commit() + conn.close() + + def _load_input_df(self) -> pd.DataFrame: + if not os.path.exists(self.input_db_): + raise FileNotFoundError(f"input_db not found: {self.input_db_}") + with sqlite3.connect(self.input_db_) as conn: + df = pd.read_sql_query( + """ + SELECT + tstamp, + tstamp_ns, + exchange_id, + instrument_id, + open, + high, + low, + close, + volume, + vwap, + num_trades + FROM md_1min_bars + """, + conn, + ) + if df.empty: + return df + + ts_ns = pd.to_datetime(df["tstamp_ns"], unit="ns", utc=True, errors="coerce") + ts_txt = pd.to_datetime(df["tstamp"], utc=True, errors="coerce") + df["tstamp"] = ts_ns.fillna(ts_txt) + df = df.dropna(subset=["tstamp", "instrument_id"]).copy() + df["exchange_id"] = df["exchange_id"].fillna("") + df["instrument_id"] = df["instrument_id"].astype(str) + df["tstamp_ns"] = df["tstamp"].astype("int64") + return df.sort_values("tstamp").reset_index(drop=True) + + def _filter_instruments(self, df: pd.DataFrame) -> pd.DataFrame: + instrument_ids = {inst.instrument_id() for inst in self.instruments_} + df = df[df["instrument_id"].isin(instrument_ids)].copy() + if "exchange_id" in df.columns: + exchange_ids = {inst.exchange_id_ for inst in self.instruments_} + df = df[ + (df["exchange_id"].isin(exchange_ids)) | (df["exchange_id"] == "") + ].copy() + return df + + def _init_output_db(self) -> sqlite3.Connection: + if os.path.exists(self.output_db_): + os.remove(self.output_db_) + conn = sqlite3.connect(self.output_db_) + conn.execute( + """ + CREATE TABLE pair_selection_history ( + tstamp TEXT, + tstamp_ns INTEGER, + pair_name TEXT, + exchange_a TEXT, + instrument_a TEXT, + exchange_b TEXT, + instrument_b TEXT, + pvalue_eg REAL, + pvalue_adf REAL, + pvalue_j REAL, + trace_stat_j REAL, + rank_eg INTEGER, + rank_adf INTEGER, + rank_j INTEGER, + composite_rank REAL + ) + """ + ) + conn.execute( + """ + CREATE INDEX idx_pair_selection_history_pair_name + ON pair_selection_history (pair_name) + """ + ) + conn.execute( + """ + CREATE UNIQUE INDEX idx_pair_selection_history_tstamp_pair + ON pair_selection_history (tstamp, pair_name) + """ + ) + conn.commit() + return conn + + def _resolve_instrument( + self, exchange_id: str, instrument_id: str + ) -> Optional[ExchangeInstrument]: + if exchange_id: + inst = self.inst_by_key_.get((exchange_id, instrument_id)) + if inst is not None: + return inst + inst = self.inst_by_id_.get(instrument_id) + if inst is None and instrument_id in self.ambiguous_ids_: + return None + return inst + + def _build_day_series( + self, df_day: pd.DataFrame + ) -> Dict[ExchangeInstrument, pd.DataFrame]: + series: Dict[ExchangeInstrument, pd.DataFrame] = {} + group_cols = ["exchange_id", "instrument_id"] + for key, group in df_day.groupby(group_cols, dropna=False): + exchange_id, instrument_id = key + inst = self._resolve_instrument(str(exchange_id or ""), str(instrument_id)) + if inst is None: + continue + df_inst = group.copy() + df_inst["price"] = [ + _extract_price_from_fields( + price_field=self.price_field_, + inst=inst, + open=float(row.open), #type: ignore + high=float(row.high), #type: ignore + low=float(row.low), #type: ignore + close=float(row.close), #type: ignore + vwap=float(row.vwap),#type: ignore + ) + for row in df_inst.itertuples(index=False) + ] + df_inst = df_inst[["tstamp", "tstamp_ns", "price", "num_trades"]] + if inst in series: + series[inst] = pd.concat([series[inst], df_inst], ignore_index=True) + else: + series[inst] = df_inst + for inst in list(series.keys()): + series[inst] = series[inst].sort_values("tstamp").reset_index(drop=True) + return series + + def _run_backtest(self, df: pd.DataFrame, conn: sqlite3.Connection) -> None: + window_minutes = self.history_depth_hours_ * 60 + window_td = pd.Timedelta(minutes=window_minutes) + step_td = pd.Timedelta(seconds=self.interval_sec_) + + df = df.copy() + df["day"] = df["tstamp"].dt.normalize() + days = sorted(df["day"].unique()) + for day in days: + day_label = pd.Timestamp(day).date() + df_day = df[df["day"] == day] + t0 = df_day["tstamp"].min() + t_last = df_day["tstamp"].max() + if t_last - t0 < window_td: + Log.warning( + f"{self.fname()}: skipping {day_label} (insufficient data)" + ) + continue + + day_series = self._build_day_series(df_day) + if len(day_series) < 2: + Log.warning( + f"{self.fname()}: skipping {day_label} (insufficient instruments)" + ) + continue + + start = t0 + expected_end = start + window_td + while expected_end <= t_last: + window_slices: Dict[ExchangeInstrument, pd.DataFrame] = {} + ts: Optional[pd.Timestamp] = None + for inst, df_inst in day_series.items(): + df_win = df_inst[ + (df_inst["tstamp"] >= start) + & (df_inst["tstamp"] < expected_end) + ] + if df_win.empty: + continue + window_slices[inst] = df_win + last_ts = df_win["tstamp"].iloc[-1] + if ts is None or last_ts > ts: + ts = last_ts + + if window_slices and ts is not None: + price_series: Dict[ExchangeInstrument, pd.DataFrame] = {} + for inst, df_win in window_slices.items(): + aggr = self._to_backtest_aggregates(df_win) + q = self.quality_.evaluate( + inst=inst, aggr=aggr, now_ts=ts + ) + if q.status_ != "PASS": + continue + price_series[inst] = df_win[["tstamp", "price"]] + pair_results = self.analyzer_.analyze(price_series) + Log.info(f"{self.fname()}: Saving Results for window ending {ts}") + self._insert_results(conn, ts, pair_results) + + start = start + step_td + expected_end = start + window_td + + @staticmethod + def _to_backtest_aggregates(df_win: pd.DataFrame) -> List[BacktestAggregate]: + aggr: List[BacktestAggregate] = [] + for tstamp_ns, num_trades in zip(df_win["tstamp_ns"], df_win["num_trades"]): + nt = None if pd.isna(num_trades) else int(num_trades) + aggr.append( + BacktestAggregate(aggr_time_ns_=int(tstamp_ns), num_trades_=nt) + ) + return aggr + + @staticmethod + def _insert_results( + conn: sqlite3.Connection, + ts: pd.Timestamp, + pair_results: Dict[str, PairStats], + ) -> None: + if not pair_results: + return + iso = ts.isoformat() + ns = int(ts.value) + rows = [] + for pair_name in sorted(pair_results.keys()): + stats = pair_results[pair_name] + rows.append( + ( + iso, + ns, + pair_name, + stats.instrument_a_.exchange_id_, + stats.instrument_a_.instrument_id(), + stats.instrument_b_.exchange_id_, + stats.instrument_b_.instrument_id(), + stats.pvalue_eg_, + stats.pvalue_adf_, + stats.pvalue_j_, + stats.trace_stat_j_, + stats.rank_eg_, + stats.rank_adf_, + stats.rank_j_, + stats.composite_rank_, + ) + ) + conn.executemany( + """ + INSERT INTO pair_selection_history ( + tstamp, + tstamp_ns, + pair_name, + exchange_a, + instrument_a, + exchange_b, + instrument_b, + pvalue_eg, + pvalue_adf, + pvalue_j, + trace_stat_j, + rank_eg, + rank_adf, + rank_j, + composite_rank + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + conn.commit() + + + class PairSelector(NamedObject): instruments_: List[ExchangeInstrument] engine_: PairSelectionEngine - rest_service_: RestService + rest_service_: Optional[RestService] + backtest_: Optional[PairSelectionBacktest] def __init__(self) -> None: App.instance().add_cmdline_arg("--oneshot", action="store_true", default=False) + App.instance().add_cmdline_arg("--backtest", action="store_true", default=False) + App.instance().add_cmdline_arg("--input_db", default=None) + App.instance().add_cmdline_arg("--output_db", default=None) App.instance().add_call(App.Stage.Config, self._on_config()) App.instance().add_call(App.Stage.Run, self.run()) @@ -458,6 +839,24 @@ class PairSelector(NamedObject): self.instruments_ = self._load_instruments(cfg) price_field = cfg.get_value("model/stat_model_price", "close") + self.backtest_ = None + self.rest_service_ = None + if App.instance().get_argument("backtest", False): + input_db = App.instance().get_argument("input_db", None) + output_db = App.instance().get_argument("output_db", None) + if not input_db or not output_db: + raise ValueError( + "--input_db and --output_db are required when --backtest is set" + ) + self.backtest_ = PairSelectionBacktest( + config=cfg, + instruments=self.instruments_, + price_field=price_field, + input_db=input_db, + output_db=output_db, + ) + return + self.engine_ = PairSelectionEngine( config=cfg, instruments=self.instruments_, @@ -499,6 +898,11 @@ class PairSelector(NamedObject): return instruments async def run(self) -> None: + if App.instance().get_argument("backtest", False): + if self.backtest_ is None: + raise RuntimeError("backtest runner not initialized") + self.backtest_.run() + return oneshot = App.instance().get_argument("oneshot", False) while True: await self.engine_.run_once() diff --git a/apps/pair_selector/pair_selector_engine.py.md b/apps/pair_selector/pair_selector_engine.py.md deleted file mode 100644 index 25cdf71..0000000 --- a/apps/pair_selector/pair_selector_engine.py.md +++ /dev/null @@ -1,394 +0,0 @@ -```python -from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import pandas as pd -from statsmodels.tsa.stattools import adfuller, coint -from statsmodels.tsa.vector_ar.vecm import coint_johansen -from statsmodels.tsa.vector_ar.vecm import coint_johansen # type: ignore -# --- -from cvttpy_tools.base import NamedObject -from cvttpy_tools.config import Config -from cvttpy_tools.logger import Log -from cvttpy_tools.timeutils import NanoPerSec, SecPerHour, current_nanoseconds -from cvttpy_tools.web.rest_client import RESTSender -# --- -from cvttpy_trading.trading.instrument import ExchangeInstrument -from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate, MdSummary - - -@dataclass -class InstrumentQuality(NamedObject): - instrument_: ExchangeInstrument - record_count_: int - latest_tstamp_: Optional[pd.Timestamp] - status_: str - reason_: str - - -@dataclass -class PairStats(NamedObject): - instrument_a_: ExchangeInstrument - instrument_b_: ExchangeInstrument - pvalue_eg_: Optional[float] - pvalue_adf_: Optional[float] - pvalue_j_: Optional[float] - trace_stat_j_: Optional[float] - rank_eg_: int = 0 - rank_adf_: int = 0 - rank_j_: int = 0 - composite_rank_: int = 0 - - def as_dict(self) -> Dict[str, Any]: - return { - "instrument_a": self.instrument_a_.instrument_id(), - "instrument_b": self.instrument_b_.instrument_id(), - "pvalue_eg": self.pvalue_eg_, - "pvalue_adf": self.pvalue_adf_, - "pvalue_j": self.pvalue_j_, - "trace_stat_j": self.trace_stat_j_, - "rank_eg": self.rank_eg_, - "rank_adf": self.rank_adf_, - "rank_j": self.rank_j_, - "composite_rank": self.composite_rank_, - } - - -class DataFetcher(NamedObject): - sender_: RESTSender - interval_sec_: int - history_depth_sec_: int - - def __init__( - self, - base_url: str, - interval_sec: int, - history_depth_sec: int, - ) -> None: - self.sender_ = RESTSender(base_url=base_url) - self.interval_sec_ = interval_sec - self.history_depth_sec_ = history_depth_sec - - def fetch(self, exch_acct: str, inst: ExchangeInstrument) -> List[MdTradesAggregate]: - rqst_data = { - "exch_acct": exch_acct, - "instrument_id": inst.instrument_id(), - "interval_sec": self.interval_sec_, - "history_depth_sec": self.history_depth_sec_, - } - response = self.sender_.send_post(endpoint="md_summary", post_body=rqst_data) - if response.status_code not in (200, 201): - Log.error( - f"{self.fname()}: error {response.status_code} for {inst.details_short()}: {response.text}") - return [] - mdsums: List[MdSummary] = MdSummary.from_REST_response(response=response) - return [ - mdsum.create_md_trades_aggregate( - exch_acct=exch_acct, exch_inst=inst, interval_sec=self.interval_sec_ - ) - for mdsum in mdsums - ] - - -class QualityChecker(NamedObject): - interval_sec_: int - - def __init__(self, interval_sec: int) -> None: - self.interval_sec_ = interval_sec - - def evaluate(self, inst: ExchangeInstrument, aggr: List[MdTradesAggregate]) -> InstrumentQuality: - if len(aggr) == 0: - return InstrumentQuality( - instrument_=inst, - record_count_=0, - latest_tstamp_=None, - status_="FAIL", - reason_="no records", - ) - - aggr_sorted = sorted(aggr, key=lambda a: a.aggr_time_ns_) - - latest_ts = pd.to_datetime(aggr_sorted[-1].aggr_time_ns_, unit="ns", utc=True) - now_ts = pd.Timestamp.utcnow() - recency_cutoff = now_ts - pd.Timedelta(seconds=2 * self.interval_sec_) - if latest_ts <= recency_cutoff: - return InstrumentQuality( - instrument_=inst, - record_count_=len(aggr_sorted), - latest_tstamp_=latest_ts, - status_="FAIL", - reason_=f"stale: latest {latest_ts} <= cutoff {recency_cutoff}", - ) - - gaps_ok, reason = self._check_gaps(aggr_sorted) - status = "PASS" if gaps_ok else "FAIL" - return InstrumentQuality( - instrument_=inst, - record_count_=len(aggr_sorted), - latest_tstamp_=latest_ts, - status_=status, - reason_=reason, - ) - - def _check_gaps(self, aggr: List[MdTradesAggregate]) -> Tuple[bool, str]: - NUM_TRADES_THRESHOLD = 50 - if len(aggr) < 2: - return True, "ok" - - interval_ns = self.interval_sec_ * NanoPerSec - for idx in range(1, len(aggr)): - prev = aggr[idx - 1] - curr = aggr[idx] - delta = curr.aggr_time_ns_ - prev.aggr_time_ns_ - missing_intervals = int(delta // interval_ns) - 1 - if missing_intervals <= 0: - continue - - prev_nt = prev.num_trades_ - next_nt = curr.num_trades_ - estimate = self._approximate_num_trades(prev_nt, next_nt) - if estimate > NUM_TRADES_THRESHOLD: - return False, ( - f"gap of {missing_intervals} interval(s), est num_trades={estimate} > {NUM_TRADES_THRESHOLD}" - ) - return True, "ok" - - @staticmethod - def _approximate_num_trades(prev_nt: int, next_nt: int) -> float: - if prev_nt is None and next_nt is None: - return 0.0 - if prev_nt is None: - return float(next_nt) - if next_nt is None: - return float(prev_nt) - return (prev_nt + next_nt) / 2.0 - - -class PairAnalyzer(NamedObject): - price_field_: str - interval_sec_: int - - def __init__(self, price_field: str, interval_sec: int) -> None: - self.price_field_ = price_field - self.interval_sec_ = interval_sec - - def analyze(self, series: Dict[ExchangeInstrument, pd.DataFrame]) -> List[PairStats]: - instruments = list(series.keys()) - results: List[PairStats] = [] - for i in range(len(instruments)): - for j in range(i + 1, len(instruments)): - inst_a = instruments[i] - inst_b = instruments[j] - df_a = series[inst_a][["tstamp", "price"]].rename( - columns={"price": "price_a"} - ) - df_b = series[inst_b][["tstamp", "price"]].rename( - columns={"price": "price_b"} - ) - merged = pd.merge(df_a, df_b, on="tstamp", how="inner").sort_values( - "tstamp" - ) - stats = self._compute_stats(inst_a, inst_b, merged) - if stats: - results.append(stats) - self._rank(results) - return results - - def _compute_stats( - self, - inst_a: ExchangeInstrument, - inst_b: ExchangeInstrument, - merged: pd.DataFrame, - ) -> Optional[PairStats]: - if len(merged) < 2: - return None - px_a = merged["price_a"].astype(float) - px_b = merged["price_b"].astype(float) - - std_a = float(px_a.std()) - std_b = float(px_b.std()) - if std_a == 0 or std_b == 0: - return None - - z_a = (px_a - float(px_a.mean())) / std_a - z_b = (px_b - float(px_b.mean())) / std_b - - p_eg: Optional[float] - p_adf: Optional[float] - p_j: Optional[float] - trace_stat: Optional[float] - - try: - p_eg = float(coint(z_a, z_b)[1]) - except Exception as exc: - Log.warning(f"{self.fname()}: EG failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}") - p_eg = None - - try: - spread = z_a - z_b - p_adf = float(adfuller(spread, maxlag=1, regression="c")[1]) - except Exception as exc: - Log.warning(f"{self.fname()}: ADF failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}") - p_adf = None - - try: - data = np.column_stack([z_a, z_b]) - res = coint_johansen(data, det_order=0, k_ar_diff=1) - trace_stat = float(res.lr1[0]) - cv10, cv5, cv1 = res.cvt[0] - if trace_stat > cv1: - p_j = 0.01 - elif trace_stat > cv5: - p_j = 0.05 - elif trace_stat > cv10: - p_j = 0.10 - else: - p_j = 1.0 - except Exception as exc: - Log.warning(f"{self.fname()}: Johansen failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}") - p_j = None - trace_stat = None - - return PairStats( - instrument_a_=inst_a, - instrument_b_=inst_b, - pvalue_eg_=p_eg, - pvalue_adf_=p_adf, - pvalue_j_=p_j, - trace_stat_j_=trace_stat, - ) - - def _rank(self, results: List[PairStats]) -> None: - self._assign_ranks(results, key=lambda r: r.pvalue_eg_, attr="rank_eg_") - self._assign_ranks(results, key=lambda r: r.pvalue_adf_, attr="rank_adf_") - self._assign_ranks(results, key=lambda r: r.pvalue_j_, attr="rank_j_") - for res in results: - res.composite_rank_ = res.rank_eg_ + res.rank_adf_ + res.rank_j_ - results.sort(key=lambda r: r.composite_rank_) - - @staticmethod - def _assign_ranks( - results: List[PairStats], key, attr: str - ) -> None: - values = [key(r) for r in results] - sorted_vals = sorted([v for v in values if v is not None]) - for res in results: - val = key(res) - if val is None: - setattr(res, attr, len(sorted_vals) + 1) - continue - rank = 1 + sum(1 for v in sorted_vals if v < val) - setattr(res, attr, rank) - - -class PairSelectionEngine(NamedObject): - config_: object - instruments_: List[ExchangeInstrument] - price_field_: str - fetcher_: DataFetcher - quality_: QualityChecker - analyzer_: PairAnalyzer - interval_sec_: int - history_depth_sec_: int - data_quality_cache_: List[InstrumentQuality] - pair_results_cache_: List[PairStats] - - def __init__( - self, - config: Config, - instruments: List[ExchangeInstrument], - price_field: str, - ) -> None: - self.config_ = config - self.instruments_ = instruments - self.price_field_ = price_field - - interval_sec = int(config.get_value("interval_sec", 0)) - history_depth_sec = int(config.get_value("history_depth_hours", 0)) * SecPerHour - base_url = config.get_value("cvtt_base_url", None) - assert interval_sec > 0, "interval_sec must be > 0" - assert history_depth_sec > 0, "history_depth_sec must be > 0" - assert base_url, "cvtt_base_url must be set" - - self.fetcher_ = DataFetcher( - base_url=base_url, - interval_sec=interval_sec, - history_depth_sec=history_depth_sec, - ) - self.quality_ = QualityChecker(interval_sec=interval_sec) - self.analyzer_ = PairAnalyzer(price_field=price_field, interval_sec=interval_sec) - - self.interval_sec_ = interval_sec - self.history_depth_sec_ = history_depth_sec - - self.data_quality_cache_ = [] - self.pair_results_cache_ = [] - - async def run_once(self) -> None: - quality_results: List[InstrumentQuality] = [] - price_series: Dict[ExchangeInstrument, pd.DataFrame] = {} - - for inst in self.instruments_: - exch_acct = inst.user_data_.get("exch_acct") or inst.exchange_id_ - aggr = self.fetcher_.fetch(exch_acct=exch_acct, inst=inst) - q = self.quality_.evaluate(inst, aggr) - quality_results.append(q) - if q.status_ != "PASS": - continue - df = self._to_dataframe(aggr, inst) - if len(df) > 0: - price_series[inst] = df - self.data_quality_cache_ = quality_results - self.pair_results_cache_ = self.analyzer_.analyze(price_series) - - def _to_dataframe(self, aggr: List[MdTradesAggregate], inst: ExchangeInstrument) -> pd.DataFrame: - rows: List[Dict[str, Any]] = [] - for item in aggr: - rows.append( - { - "tstamp": pd.to_datetime(item.aggr_time_ns_, unit="ns", utc=True), - "price": self._extract_price(item, inst), - "num_trades": item.num_trades_, - } - ) - df = pd.DataFrame(rows) - return df.sort_values("tstamp").reset_index(drop=True) - - def _extract_price(self, aggr: MdTradesAggregate, inst: ExchangeInstrument) -> float: - price_field = self.price_field_ - # MdTradesAggregate inherits hist bar with fields open_, high_, low_, close_, vwap_ - field_map = { - "open": aggr.open_, - "high": aggr.high_, - "low": aggr.low_, - "close": aggr.close_, - "vwap": aggr.vwap_, - } - raw = field_map.get(price_field, aggr.close_) - return inst.get_price(raw) - - def sleep_seconds_until_next_cycle(self) -> float: - now_ns = current_nanoseconds() - interval_ns = self.interval_sec_ * NanoPerSec - next_boundary = (now_ns // interval_ns + 1) * interval_ns - return max(0.0, (next_boundary - now_ns) / NanoPerSec) - - def quality_dicts(self) -> List[Dict[str, Any]]: - res: List[Dict[str, Any]] = [] - for q in self.data_quality_cache_: - res.append( - { - "instrument": q.instrument_.instrument_id(), - "record_count": q.record_count_, - "latest_tstamp": q.latest_tstamp_.isoformat() if q.latest_tstamp_ else None, - "status": q.status_, - "reason": q.reason_, - } - ) - return res - - def pair_dicts(self) -> List[Dict[str, Any]]: - return [p.as_dict() for p in self.pair_results_cache_] -``` \ No newline at end of file diff --git a/research/notebooks/pair_select_hist.ipynb b/research/notebooks/pair_select_hist.ipynb new file mode 100644 index 0000000..64cef3d --- /dev/null +++ b/research/notebooks/pair_select_hist.ipynb @@ -0,0 +1,329 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pair Selection History\n", + "\n", + "Interactive notebook for exploring pair selection history from a SQLite database.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Usage**\n", + "- Enter the SQLite `db_path` (file path).\n", + "- Click `Load pairs` to populate the dropdown.\n", + "- Select a `pair_name`, then click `Plot`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "668ebf19", + "metadata": {}, + "source": [ + "# Settings" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c78db847", + "metadata": {}, + "outputs": [], + "source": [ + "import sqlite3\n", + "from pathlib import Path\n", + "\n", + "import pandas as pd\n", + "import plotly.express as px\n", + "import ipywidgets as widgets\n", + "from IPython.display import display\n" + ] + }, + { + "cell_type": "markdown", + "id": "e7ac6adc", + "metadata": {}, + "source": [ + "# Data Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "766bcf9f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a0582ba92c744e08b9267176f463701a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Text(value='', description='pair_db', layout=Layout(width='80%'), placeholder='/path/to/pairs.d…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "eeaaaf780af64a278b0948c0c787cf33", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "db_path = widgets.Text(\n", + " value='',\n", + " placeholder='/path/to/pairs.db',\n", + " description='pair_db',\n", + " layout=widgets.Layout(width='80%')\n", + ")\n", + "\n", + "md_db_path = widgets.Text(\n", + " value='',\n", + " placeholder='/path/to/market_data.db',\n", + " description='md_db',\n", + " layout=widgets.Layout(width='80%')\n", + ")\n", + "\n", + "load_button = widgets.Button(description='Load pairs', button_style='info')\n", + "plot_button = widgets.Button(description='Plot', button_style='primary')\n", + "\n", + "pair_name = widgets.Dropdown(\n", + " options=[],\n", + " value=None,\n", + " description='pair_name',\n", + " layout=widgets.Layout(width='80%')\n", + ")\n", + "\n", + "status = widgets.HTML(value='')\n", + "output = widgets.Output()\n", + "\n", + "controls = widgets.VBox([\n", + " db_path,\n", + " md_db_path,\n", + " widgets.HBox([load_button, plot_button]),\n", + " pair_name,\n", + " status,\n", + "])\n", + "\n", + "display(controls, output)\n" + ] + }, + { + "cell_type": "markdown", + "id": "a4d47855", + "metadata": {}, + "source": [ + "# Processing" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2c710f51", + "metadata": {}, + "outputs": [], + "source": [ + "PLOT_WIDTH = 1100\n", + "PLOT_HEIGHT = 320\n", + "\n", + "def _connect(path: str):\n", + " if not path:\n", + " raise ValueError('Please provide db_path.')\n", + " p = Path(path).expanduser().resolve()\n", + " if not p.exists():\n", + " raise FileNotFoundError(f'Database not found: {p}')\n", + " return sqlite3.connect(p)\n", + "\n", + "\n", + "def _parse_tstamp(series: pd.Series) -> pd.Series:\n", + " return pd.to_datetime(series, utc=True, errors='coerce').dt.tz_convert(None)\n", + "\n", + "\n", + "def _style_fig(fig, tmin, tmax):\n", + " fig.update_layout(\n", + " legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='left', x=0),\n", + " margin=dict(l=50, r=20, t=60, b=40),\n", + " height=PLOT_HEIGHT,\n", + " width=PLOT_WIDTH,\n", + " )\n", + " fig.update_xaxes(range=[tmin, tmax])\n", + "\n", + "\n", + "def _load_pairs(_=None):\n", + " status.value = ''\n", + " with output:\n", + " output.clear_output()\n", + " try:\n", + " with _connect(db_path.value) as conn:\n", + " rows = conn.execute(\n", + " \"WITH first_rows AS (\"\n", + " \" SELECT pair_name, MIN(tstamp) AS tmin \"\n", + " \" FROM pair_selection_history \"\n", + " \" GROUP BY pair_name\"\n", + " \" ) \"\n", + " \"SELECT p.pair_name \"\n", + " \"FROM pair_selection_history p \"\n", + " \"JOIN first_rows f \"\n", + " \" ON p.pair_name = f.pair_name AND p.tstamp = f.tmin \"\n", + " \"GROUP BY p.pair_name \"\n", + " \"ORDER BY MIN(p.composite_rank), p.pair_name\"\n", + " ).fetchall()\n", + " options = [r[0] for r in rows]\n", + " pair_name.options = options\n", + " pair_name.value = options[0] if options else None\n", + " status.value = f'Loaded {len(options)} pairs.'\n", + " except Exception as exc:\n", + " status.value = f\"Error: {exc}\"\n", + "\n", + "\n", + "def _plot(_=None):\n", + " status.value = ''\n", + " with output:\n", + " output.clear_output()\n", + " try:\n", + " if not pair_name.value:\n", + " raise ValueError('Please select a pair_name.')\n", + " if not md_db_path.value:\n", + " raise ValueError('Please provide md_db path.')\n", + " query = (\n", + " 'SELECT tstamp, pvalue_eg, pvalue_adf, rank_eg, rank_adf, '\n", + " 'exchange_a, instrument_a, exchange_b, instrument_b '\n", + " 'FROM pair_selection_history '\n", + " 'WHERE pair_name = ? '\n", + " 'ORDER BY tstamp'\n", + " )\n", + " with _connect(db_path.value) as conn:\n", + " df = pd.read_sql_query(query, conn, params=(pair_name.value,))\n", + " if df.empty:\n", + " raise ValueError('No data for selected pair_name.')\n", + " df['tstamp'] = _parse_tstamp(df['tstamp'])\n", + " df = df.dropna(subset=['tstamp'])\n", + " if df.empty:\n", + " raise ValueError('No valid timestamps in pair selection data.')\n", + " tmin = df['tstamp'].min()\n", + " tmax = df['tstamp'].max()\n", + "\n", + " first_row = df.dropna(subset=['exchange_a', 'instrument_a', 'exchange_b', 'instrument_b']).iloc[0]\n", + " ex_a = first_row['exchange_a']\n", + " id_a = first_row['instrument_a']\n", + " ex_b = first_row['exchange_b']\n", + " id_b = first_row['instrument_b']\n", + "\n", + " fig_p = px.line(\n", + " df,\n", + " x='tstamp',\n", + " y=['pvalue_eg', 'pvalue_adf'],\n", + " title=f'P-Values Over Time: {pair_name.value}',\n", + " labels={'value': 'p-value', 'variable': 'metric', 'tstamp': 'timestamp'}\n", + " )\n", + " fig_p.update_layout(legend_title_text='metric')\n", + " _style_fig(fig_p, tmin, tmax)\n", + "\n", + " # fig_r = px.line(\n", + " # df,\n", + " # x='tstamp',\n", + " # y=['rank_eg', 'rank_adf'],\n", + " # title=f'Ranks Over Time: {pair_name.value}',\n", + " # labels={'value': 'rank', 'variable': 'metric', 'tstamp': 'timestamp'}\n", + " # )\n", + " # fig_r.update_layout(legend_title_text='metric')\n", + " # _style_fig(fig_r, tmin, tmax)\n", + "\n", + " md_query = (\n", + " 'SELECT tstamp, close FROM md_1min_bars '\n", + " 'WHERE exchange_id = ? AND instrument_id = ? '\n", + " 'ORDER BY tstamp'\n", + " )\n", + " with _connect(md_db_path.value) as md_conn:\n", + " md_a = pd.read_sql_query(md_query, md_conn, params=(ex_a, id_a))\n", + " md_b = pd.read_sql_query(md_query, md_conn, params=(ex_b, id_b))\n", + " if md_a.empty or md_b.empty:\n", + " raise ValueError('Market data not found for selected instruments.')\n", + " md_a['tstamp'] = _parse_tstamp(md_a['tstamp'])\n", + " md_b['tstamp'] = _parse_tstamp(md_b['tstamp'])\n", + " md_a = md_a.dropna(subset=['tstamp', 'close'])\n", + " md_b = md_b.dropna(subset=['tstamp', 'close'])\n", + " md_a = md_a[(md_a['tstamp'] >= tmin) & (md_a['tstamp'] <= tmax)]\n", + " md_b = md_b[(md_b['tstamp'] >= tmin) & (md_b['tstamp'] <= tmax)]\n", + " if md_a.empty or md_b.empty:\n", + " raise ValueError('Market data is outside the pair selection time range.')\n", + " md_a = md_a.sort_values('tstamp')\n", + " md_b = md_b.sort_values('tstamp')\n", + " md_a['scaled_close'] = (md_a['close'] - md_a['close'].iloc[0]) / md_a['close'].iloc[0] * 100\n", + " md_b['scaled_close'] = (md_b['close'] - md_b['close'].iloc[0]) / md_b['close'].iloc[0] * 100\n", + "\n", + " md_plot = pd.DataFrame({\n", + " 'tstamp': md_a['tstamp'],\n", + " f'{ex_a}:{id_a}': md_a['scaled_close'],\n", + " })\n", + " md_plot = md_plot.merge(\n", + " pd.DataFrame({\n", + " 'tstamp': md_b['tstamp'],\n", + " f'{ex_b}:{id_b}': md_b['scaled_close'],\n", + " }),\n", + " on='tstamp',\n", + " how='outer'\n", + " ).sort_values('tstamp')\n", + "\n", + " fig_m = px.line(\n", + " md_plot,\n", + " x='tstamp',\n", + " y=[f'{ex_a}:{id_a}', f'{ex_b}:{id_b}'],\n", + " title='Scaled Close Price Change (%)',\n", + " labels={'value': 'scaled % change', 'variable': 'instrument', 'tstamp': 'timestamp'}\n", + " )\n", + " fig_m.update_layout(legend_title_text='instrument')\n", + " _style_fig(fig_m, tmin, tmax)\n", + "\n", + " with output:\n", + " display(fig_p)\n", + " # display(fig_r)\n", + " display(fig_m)\n", + " except Exception as exc:\n", + " status.value = f\"Error: {exc}\"\n", + "\n", + "\n", + "load_button.on_click(_load_pairs)\n", + "plot_button.on_click(_plot)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3.12-venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}