from __future__ import annotations import asyncio from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple from aiohttp import web import numpy as np import pandas as pd from statsmodels.tsa.stattools import adfuller, coint # type: ignore from statsmodels.tsa.vector_ar.vecm import coint_johansen # type: ignore from cvttpy_tools.app import App from cvttpy_tools.base import NamedObject from cvttpy_tools.config import Config, CvttAppConfig from cvttpy_tools.logger import Log from cvttpy_tools.timeutils import NanoPerSec, SecPerHour, current_nanoseconds from cvttpy_tools.web.rest_client import RESTSender from cvttpy_tools.web.rest_service import RestService from cvttpy_trading.trading.exchange_config import ExchangeAccounts from cvttpy_trading.trading.instrument import ExchangeInstrument from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate, MdSummary from pairs_trading.apps.pair_selector.renderer import HtmlRenderer @dataclass class InstrumentQuality(NamedObject): instrument_: ExchangeInstrument record_count_: int latest_tstamp_: Optional[pd.Timestamp] status_: str reason_: str @dataclass class PairStats(NamedObject): pair_name_: str instrument_a_: ExchangeInstrument instrument_b_: ExchangeInstrument pvalue_eg_: Optional[float] pvalue_adf_: Optional[float] pvalue_j_: Optional[float] trace_stat_j_: Optional[float] rank_eg_: int = 0 rank_adf_: int = 0 rank_j_: int = 0 composite_rank_: int = 0 def as_dict(self) -> Dict[str, Any]: return { "instrument_a": self.instrument_a_.instrument_id(), "instrument_b": self.instrument_b_.instrument_id(), "pvalue_eg": self.pvalue_eg_, "pvalue_adf": self.pvalue_adf_, "pvalue_j": self.pvalue_j_, "trace_stat_j": self.trace_stat_j_, "rank_eg": self.rank_eg_, "rank_adf": self.rank_adf_, "rank_j": self.rank_j_, "composite_rank": self.composite_rank_, } class DataFetcher(NamedObject): sender_: RESTSender interval_sec_: int history_depth_sec_: int def __init__( self, base_url: str, interval_sec: int, history_depth_sec: int, ) -> None: self.sender_ = RESTSender(base_url=base_url) self.interval_sec_ = interval_sec self.history_depth_sec_ = history_depth_sec def fetch( self, exch_acct: str, inst: ExchangeInstrument ) -> List[MdTradesAggregate]: rqst_data = { "exch_acct": exch_acct, "instrument_id": inst.instrument_id(), "interval_sec": self.interval_sec_, "history_depth_sec": self.history_depth_sec_, } response = self.sender_.send_post(endpoint="md_summary", post_body=rqst_data) if response.status_code not in (200, 201): Log.error( f"{self.fname()}: error {response.status_code} for {inst.details_short()}: {response.text}" ) return [] mdsums: List[MdSummary] = MdSummary.from_REST_response(response=response) return [ mdsum.create_md_trades_aggregate( exch_acct=exch_acct, exch_inst=inst, interval_sec=self.interval_sec_ ) for mdsum in mdsums ] class QualityChecker(NamedObject): interval_sec_: int def __init__(self, interval_sec: int) -> None: self.interval_sec_ = interval_sec def evaluate( self, inst: ExchangeInstrument, aggr: List[MdTradesAggregate] ) -> InstrumentQuality: if len(aggr) == 0: return InstrumentQuality( instrument_=inst, record_count_=0, latest_tstamp_=None, status_="FAIL", reason_="no records", ) aggr_sorted = sorted(aggr, key=lambda a: a.aggr_time_ns_) latest_ts = pd.to_datetime(aggr_sorted[-1].aggr_time_ns_, unit="ns", utc=True) now_ts = pd.Timestamp.utcnow() recency_cutoff = now_ts - pd.Timedelta(seconds=2 * self.interval_sec_) if latest_ts <= recency_cutoff: return InstrumentQuality( instrument_=inst, record_count_=len(aggr_sorted), latest_tstamp_=latest_ts, status_="FAIL", reason_=f"stale: latest {latest_ts} <= cutoff {recency_cutoff}", ) gaps_ok, reason = self._check_gaps(aggr_sorted) status = "PASS" if gaps_ok else "FAIL" return InstrumentQuality( instrument_=inst, record_count_=len(aggr_sorted), latest_tstamp_=latest_ts, status_=status, reason_=reason, ) def _check_gaps(self, aggr: List[MdTradesAggregate]) -> Tuple[bool, str]: NUM_TRADES_THRESHOLD = 50 if len(aggr) < 2: return True, "ok" interval_ns = self.interval_sec_ * NanoPerSec for idx in range(1, len(aggr)): prev = aggr[idx - 1] curr = aggr[idx] delta = curr.aggr_time_ns_ - prev.aggr_time_ns_ missing_intervals = int(delta // interval_ns) - 1 if missing_intervals <= 0: continue prev_nt = prev.num_trades_ next_nt = curr.num_trades_ estimate = self._approximate_num_trades(prev_nt, next_nt) if estimate > NUM_TRADES_THRESHOLD: return False, ( f"gap of {missing_intervals} interval(s), est num_trades={estimate} > {NUM_TRADES_THRESHOLD}" ) return True, "ok" @staticmethod def _approximate_num_trades(prev_nt: int, next_nt: int) -> float: if prev_nt is None and next_nt is None: return 0.0 if prev_nt is None: return float(next_nt) if next_nt is None: return float(prev_nt) return (prev_nt + next_nt) / 2.0 class PairAnalyzer(NamedObject): price_field_: str interval_sec_: int def __init__(self, price_field: str, interval_sec: int) -> None: self.price_field_ = price_field self.interval_sec_ = interval_sec def analyze( self, series: Dict[ExchangeInstrument, pd.DataFrame] ) -> Dict[str, PairStats]: instruments = list(series.keys()) results: Dict[str, PairStats] = {} for i in range(len(instruments)): for j in range(i + 1, len(instruments)): inst_a, inst_b, pair_name = self._normalized_pair( instruments[i], instruments[j] ) df_a = series[inst_a][["tstamp", "price"]].rename( columns={"price": "price_a"} ) df_b = series[inst_b][["tstamp", "price"]].rename( columns={"price": "price_b"} ) merged = pd.merge(df_a, df_b, on="tstamp", how="inner").sort_values( "tstamp" ) stats = self._compute_stats(inst_a, inst_b, pair_name, merged) if stats: results[pair_name] = stats return self._rank(results) def _compute_stats( self, inst_a: ExchangeInstrument, inst_b: ExchangeInstrument, pair_name: str, merged: pd.DataFrame, ) -> Optional[PairStats]: if len(merged) < 2: return None px_a = merged["price_a"].astype(float) px_b = merged["price_b"].astype(float) std_a = float(px_a.std()) std_b = float(px_b.std()) if std_a == 0 or std_b == 0: return None z_a = (px_a - float(px_a.mean())) / std_a z_b = (px_b - float(px_b.mean())) / std_b p_eg: Optional[float] p_adf: Optional[float] p_j: Optional[float] trace_stat: Optional[float] try: p_eg = float(coint(z_a, z_b)[1]) except Exception as exc: Log.warning( f"{self.fname()}: EG failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}" ) p_eg = None try: spread = z_a - z_b p_adf = float(adfuller(spread, maxlag=1, regression="c")[1]) except Exception as exc: Log.warning( f"{self.fname()}: ADF failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}" ) p_adf = None try: data = np.column_stack([z_a, z_b]) res = coint_johansen(data, det_order=0, k_ar_diff=1) trace_stat = float(res.lr1[0]) cv10, cv5, cv1 = res.cvt[0] if trace_stat > cv1: p_j = 0.01 elif trace_stat > cv5: p_j = 0.05 elif trace_stat > cv10: p_j = 0.10 else: p_j = 1.0 except Exception as exc: Log.warning( f"{self.fname()}: Johansen failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}" ) p_j = None trace_stat = None return PairStats( pair_name_=pair_name, instrument_a_=inst_a, instrument_b_=inst_b, pvalue_eg_=p_eg, pvalue_adf_=p_adf, pvalue_j_=p_j, trace_stat_j_=trace_stat, ) def _rank(self, results: Dict[str, PairStats]) -> Dict[str, PairStats]: ranked = list(results.values()) self._assign_ranks(ranked, key=lambda r: r.pvalue_eg_, attr="rank_eg_") self._assign_ranks(ranked, key=lambda r: r.pvalue_adf_, attr="rank_adf_") self._assign_ranks(ranked, key=lambda r: r.pvalue_j_, attr="rank_j_") for res in ranked: res.composite_rank_ = res.rank_eg_ + res.rank_adf_ + res.rank_j_ ranked.sort(key=lambda r: r.composite_rank_) return {res.pair_name_: res for res in ranked} @staticmethod def _normalized_pair( inst_a: ExchangeInstrument, inst_b: ExchangeInstrument ) -> Tuple[ExchangeInstrument, ExchangeInstrument, str]: inst_a_id = PairAnalyzer._pair_label(inst_a.instrument_id()) inst_b_id = PairAnalyzer._pair_label(inst_b.instrument_id()) if inst_a_id <= inst_b_id: return inst_a, inst_b, f"{inst_a_id}<->{inst_b_id}" return inst_b, inst_a, f"{inst_b_id}<->{inst_a_id}" @staticmethod def _pair_label(instrument_id: str) -> str: if instrument_id.startswith("PAIR-"): return instrument_id[len("PAIR-") :] return instrument_id @staticmethod def _assign_ranks(results: List[PairStats], key, attr: str) -> None: values = [key(r) for r in results] sorted_vals = sorted([v for v in values if v is not None]) for res in results: val = key(res) if val is None: setattr(res, attr, len(sorted_vals) + 1) continue rank = 1 + sum(1 for v in sorted_vals if v < val) setattr(res, attr, rank) class PairSelectionEngine(NamedObject): config_: object instruments_: List[ExchangeInstrument] price_field_: str fetcher_: DataFetcher quality_: QualityChecker analyzer_: PairAnalyzer interval_sec_: int history_depth_sec_: int data_quality_cache_: List[InstrumentQuality] pair_results_cache_: Dict[str, PairStats] def __init__( self, config: Config, instruments: List[ExchangeInstrument], price_field: str, ) -> None: self.config_ = config self.instruments_ = instruments self.price_field_ = price_field interval_sec = int(config.get_value("interval_sec", 0)) history_depth_sec = int(config.get_value("history_depth_hours", 0)) * SecPerHour base_url = config.get_value("cvtt_base_url", None) assert interval_sec > 0, "interval_sec must be > 0" assert history_depth_sec > 0, "history_depth_sec must be > 0" assert base_url, "cvtt_base_url must be set" self.fetcher_ = DataFetcher( base_url=base_url, interval_sec=interval_sec, history_depth_sec=history_depth_sec, ) self.quality_ = QualityChecker(interval_sec=interval_sec) self.analyzer_ = PairAnalyzer( price_field=price_field, interval_sec=interval_sec ) self.interval_sec_ = interval_sec self.history_depth_sec_ = history_depth_sec self.data_quality_cache_ = [] self.pair_results_cache_ = {} async def run_once(self) -> None: quality_results: List[InstrumentQuality] = [] price_series: Dict[ExchangeInstrument, pd.DataFrame] = {} for inst in self.instruments_: exch_acct = inst.user_data_.get("exch_acct") or inst.exchange_id_ aggr = self.fetcher_.fetch(exch_acct=exch_acct, inst=inst) q = self.quality_.evaluate(inst, aggr) quality_results.append(q) if q.status_ != "PASS": continue df = self._to_dataframe(aggr, inst) if len(df) > 0: price_series[inst] = df self.data_quality_cache_ = quality_results self.pair_results_cache_ = self.analyzer_.analyze(price_series) def _to_dataframe( self, aggr: List[MdTradesAggregate], inst: ExchangeInstrument ) -> pd.DataFrame: rows: List[Dict[str, Any]] = [] for item in aggr: rows.append( { "tstamp": pd.to_datetime(item.aggr_time_ns_, unit="ns", utc=True), "price": self._extract_price(item, inst), "num_trades": item.num_trades_, } ) df = pd.DataFrame(rows) return df.sort_values("tstamp").reset_index(drop=True) def _extract_price( self, aggr: MdTradesAggregate, inst: ExchangeInstrument ) -> float: price_field = self.price_field_ # MdTradesAggregate inherits hist bar with fields open_, high_, low_, close_, vwap_ field_map = { "open": aggr.open_, "high": aggr.high_, "low": aggr.low_, "close": aggr.close_, "vwap": aggr.vwap_, } raw = field_map.get(price_field, aggr.close_) return inst.get_price(raw) def sleep_seconds_until_next_cycle(self) -> float: now_ns = current_nanoseconds() interval_ns = self.interval_sec_ * NanoPerSec next_boundary = (now_ns // interval_ns + 1) * interval_ns return max(0.0, (next_boundary - now_ns) / NanoPerSec) def quality_dicts(self) -> List[Dict[str, Any]]: res: List[Dict[str, Any]] = [] for q in self.data_quality_cache_: res.append( { "instrument": q.instrument_.instrument_id(), "record_count": q.record_count_, "latest_tstamp": ( q.latest_tstamp_.isoformat() if q.latest_tstamp_ else None ), "status": q.status_, "reason": q.reason_, } ) return res def pair_dicts(self) -> Dict[str, Dict[str, Any]]: return { pair_name: stats.as_dict() for pair_name, stats in self.pair_results_cache_.items() } class PairSelector(NamedObject): instruments_: List[ExchangeInstrument] engine_: PairSelectionEngine rest_service_: RestService def __init__(self) -> None: App.instance().add_cmdline_arg("--oneshot", action="store_true", default=False) App.instance().add_call(App.Stage.Config, self._on_config()) App.instance().add_call(App.Stage.Run, self.run()) async def _on_config(self) -> None: cfg = CvttAppConfig.instance() self.instruments_ = self._load_instruments(cfg) price_field = cfg.get_value("model/stat_model_price", "close") self.engine_ = PairSelectionEngine( config=cfg, instruments=self.instruments_, price_field=price_field, ) self.rest_service_ = RestService(config_key="/api/REST") self.rest_service_.add_handler("GET", "/data_quality", self._on_data_quality) self.rest_service_.add_handler( "GET", "/pair_selection", self._on_pair_selection ) def _load_instruments(self, cfg: CvttAppConfig) -> List[ExchangeInstrument]: instruments_cfg = cfg.get_value("instruments", []) instruments: List[ExchangeInstrument] = [] assert len(instruments_cfg) >= 2, "at least two instruments required" for item in instruments_cfg: if isinstance(item, str): parts = item.split(":", 1) if len(parts) != 2: raise ValueError(f"invalid instrument format: {item}") exch_acct, instrument_id = parts elif isinstance(item, dict): exch_acct = item.get("exch_acct", "") instrument_id = item.get("instrument_id", "") if not exch_acct or not instrument_id: raise ValueError(f"invalid instrument config: {item}") else: raise ValueError(f"unsupported instrument entry: {item}") exch_inst = ExchangeAccounts.instance().get_exchange_instrument( exch_acct=exch_acct, instrument_id=instrument_id ) assert ( exch_inst is not None ), f"no ExchangeInstrument for {exch_acct}:{instrument_id}" exch_inst.user_data_["exch_acct"] = exch_acct instruments.append(exch_inst) return instruments async def run(self) -> None: oneshot = App.instance().get_argument("oneshot", False) while True: await self.engine_.run_once() if oneshot: break sleep_for = self.engine_.sleep_seconds_until_next_cycle() await asyncio.sleep(sleep_for) async def _on_data_quality(self, request: web.Request) -> web.Response: fmt = request.query.get("format", "html").lower() quality = self.engine_.quality_dicts() if fmt == "json": return web.json_response(quality) return web.Response( text=HtmlRenderer.render_data_quality(quality), content_type="text/html" ) async def _on_pair_selection(self, request: web.Request) -> web.Response: fmt = request.query.get("format", "html").lower() pairs = self.engine_.pair_dicts() if fmt == "json": return web.json_response(pairs) return web.Response( text=HtmlRenderer.render_pairs(pairs), content_type="text/html" ) if __name__ == "__main__": App() CvttAppConfig() PairSelector() App.instance().run()