pairs_trading/apps/pair_selector/pair_selector.py
2026-02-03 20:46:01 +00:00

534 lines
19 KiB
Python

from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from aiohttp import web
import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import adfuller, coint # type: ignore
from statsmodels.tsa.vector_ar.vecm import coint_johansen # type: ignore
from cvttpy_tools.app import App
from cvttpy_tools.base import NamedObject
from cvttpy_tools.config import Config, CvttAppConfig
from cvttpy_tools.logger import Log
from cvttpy_tools.timeutils import NanoPerSec, SecPerHour, current_nanoseconds
from cvttpy_tools.web.rest_client import RESTSender
from cvttpy_tools.web.rest_service import RestService
from cvttpy_trading.trading.exchange_config import ExchangeAccounts
from cvttpy_trading.trading.instrument import ExchangeInstrument
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate, MdSummary
from pairs_trading.apps.pair_selector.renderer import HtmlRenderer
@dataclass
class InstrumentQuality(NamedObject):
    """Outcome of the data-quality check for a single instrument.

    Produced by ``QualityChecker.evaluate``; ``PairSelectionEngine`` only passes
    instruments whose ``status_`` is "PASS" on to pair analysis.
    """

    # Instrument the check was run for.
    instrument_: ExchangeInstrument
    # Number of aggregate records returned for the instrument.
    record_count_: int
    # UTC timestamp of the newest record; None when no records were returned.
    latest_tstamp_: Optional[pd.Timestamp]
    # "PASS" or "FAIL".
    status_: str
    # Human-readable explanation, e.g. "ok", "no records", "stale: ...", "gap of ...".
    reason_: str
@dataclass
class PairStats(NamedObject):
    """Cointegration statistics and rankings for one normalized instrument pair.

    Each p-value field is ``None`` when the corresponding test could not be
    computed. The ``rank_*`` fields default to 0 and are filled in once all pairs
    have been evaluated: ranks are 1-based (lower is better) and
    ``composite_rank_`` is the sum of the three per-test ranks.
    """

    pair_name_: str
    instrument_a_: ExchangeInstrument
    instrument_b_: ExchangeInstrument
    pvalue_eg_: Optional[float]  # Engle-Granger test
    pvalue_adf_: Optional[float]  # ADF test on the z-score spread
    pvalue_j_: Optional[float]  # Johansen test, bucketed from critical values
    trace_stat_j_: Optional[float]  # Johansen trace statistic for r = 0
    rank_eg_: int = 0
    rank_adf_: int = 0
    rank_j_: int = 0
    composite_rank_: int = 0

    def as_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly view of the pair.

        Instruments are represented by their instrument ids. Every other entry
        mirrors the attribute named ``<key>_``. ``pair_name_`` is not included,
        because callers already key the result by pair name.
        """
        view: Dict[str, Any] = {
            "instrument_a": self.instrument_a_.instrument_id(),
            "instrument_b": self.instrument_b_.instrument_id(),
        }
        # Order matters for the rendered output; keep it stable.
        mirrored_keys = (
            "pvalue_eg",
            "pvalue_adf",
            "pvalue_j",
            "trace_stat_j",
            "rank_eg",
            "rank_adf",
            "rank_j",
            "composite_rank",
        )
        for key in mirrored_keys:
            view[key] = getattr(self, key + "_")
        return view
class DataFetcher(NamedObject):
    """Fetch aggregated trade bars from the ``md_summary`` REST endpoint."""

    sender_: RESTSender
    interval_sec_: int
    history_depth_sec_: int

    def __init__(
        self,
        base_url: str,
        interval_sec: int,
        history_depth_sec: int,
    ) -> None:
        """
        Args:
            base_url: base URL handed to ``RESTSender``.
            interval_sec: bar interval requested from the service.
            history_depth_sec: how far back to request history.
        """
        self.sender_ = RESTSender(base_url=base_url)
        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec

    def fetch(
        self, exch_acct: str, inst: ExchangeInstrument
    ) -> List[MdTradesAggregate]:
        """Return the trade aggregates for ``inst`` on ``exch_acct``.

        Every failure mode is logged and reported as an empty list:
        - a request that raises,
        - a non-200/201 status,
        - a response that cannot be converted.

        ``QualityChecker.evaluate`` classifies an empty list as "no records".
        One bad instrument, or a transient outage, therefore does not abort
        the whole selection cycle.

        NOTE(review): ``send_post`` is called synchronously from the async
        ``run_once``. If it performs blocking I/O, the event loop stalls while
        it runs.
        """
        rqst_data = {
            "exch_acct": exch_acct,
            "instrument_id": inst.instrument_id(),
            "interval_sec": self.interval_sec_,
            "history_depth_sec": self.history_depth_sec_,
        }
        try:
            response = self.sender_.send_post(
                endpoint="md_summary", post_body=rqst_data
            )
        except Exception as exc:  # network/transport errors from the REST client
            Log.error(
                f"{self.fname()}: request failed for {inst.details_short()}: {exc}"
            )
            return []
        if response.status_code not in (200, 201):
            Log.error(
                f"{self.fname()}: error {response.status_code} for {inst.details_short()}: {response.text}"
            )
            return []
        try:
            mdsums: List[MdSummary] = MdSummary.from_REST_response(response=response)
            return [
                mdsum.create_md_trades_aggregate(
                    exch_acct=exch_acct, exch_inst=inst, interval_sec=self.interval_sec_
                )
                for mdsum in mdsums
            ]
        except Exception as exc:  # malformed payload; treat like an HTTP error
            Log.error(
                f"{self.fname()}: cannot parse md_summary response for {inst.details_short()}: {exc}"
            )
            return []
class QualityChecker(NamedObject):
    """Decide whether an instrument's bar history is usable for pair analysis."""

    # Default for the gap check. A gap fails when the bars on either side of it
    # average more than this many trades.
    DEFAULT_NUM_TRADES_THRESHOLD = 50

    interval_sec_: int
    num_trades_threshold_: int

    def __init__(
        self,
        interval_sec: int,
        num_trades_threshold: int = DEFAULT_NUM_TRADES_THRESHOLD,
    ) -> None:
        """
        Args:
            interval_sec: expected spacing between consecutive bars.
            num_trades_threshold: estimated trade count above which a gap in the
                bar history fails the check.
        """
        self.interval_sec_ = interval_sec
        self.num_trades_threshold_ = num_trades_threshold

    def evaluate(
        self, inst: ExchangeInstrument, aggr: List[MdTradesAggregate]
    ) -> InstrumentQuality:
        """Classify the bars in ``aggr`` for ``inst`` as "PASS" or "FAIL".

        FAIL reasons, checked in this order:
        - no records at all;
        - stale data: the newest bar is not later than ``now - 2 * interval``;
        - a gap between consecutive bars next to an active period
          (see ``_check_gaps``).
        """
        if not aggr:
            return InstrumentQuality(
                instrument_=inst,
                record_count_=0,
                latest_tstamp_=None,
                status_="FAIL",
                reason_="no records",
            )
        aggr_sorted = sorted(aggr, key=lambda a: a.aggr_time_ns_)
        latest_ts = pd.to_datetime(aggr_sorted[-1].aggr_time_ns_, unit="ns", utc=True)
        # pd.Timestamp.utcnow() is deprecated as of pandas 2.2.
        # This is the equivalent tz-aware UTC "now".
        now_ts = pd.Timestamp.now(tz="UTC")
        recency_cutoff = now_ts - pd.Timedelta(seconds=2 * self.interval_sec_)
        if latest_ts <= recency_cutoff:
            return InstrumentQuality(
                instrument_=inst,
                record_count_=len(aggr_sorted),
                latest_tstamp_=latest_ts,
                status_="FAIL",
                reason_=f"stale: latest {latest_ts} <= cutoff {recency_cutoff}",
            )
        gaps_ok, reason = self._check_gaps(aggr_sorted)
        return InstrumentQuality(
            instrument_=inst,
            record_count_=len(aggr_sorted),
            latest_tstamp_=latest_ts,
            status_="PASS" if gaps_ok else "FAIL",
            reason_=reason,
        )

    def _check_gaps(self, aggr: List[MdTradesAggregate]) -> Tuple[bool, str]:
        """Scan time-sorted ``aggr`` for missing intervals; return ``(ok, reason)``.

        A gap is one or more whole intervals without a bar between two neighbours.
        It fails the check when the estimated trade count around it exceeds
        ``num_trades_threshold_``. The estimate is the mean ``num_trades_`` of the
        neighbouring bars. Presumably this is meant to separate missing data on an
        active instrument from a genuinely quiet market.
        """
        threshold = self.num_trades_threshold_
        interval_ns = self.interval_sec_ * NanoPerSec
        for prev, curr in zip(aggr, aggr[1:]):
            delta = curr.aggr_time_ns_ - prev.aggr_time_ns_
            missing_intervals = int(delta // interval_ns) - 1
            if missing_intervals <= 0:
                continue
            estimate = self._approximate_num_trades(prev.num_trades_, curr.num_trades_)
            if estimate > threshold:
                return False, (
                    f"gap of {missing_intervals} interval(s), est num_trades={estimate} > {threshold}"
                )
        return True, "ok"

    @staticmethod
    def _approximate_num_trades(
        prev_nt: Optional[int], next_nt: Optional[int]
    ) -> float:
        """Average two neighbouring trade counts.

        If one count is ``None``, the other is used alone. If both are ``None``,
        the result is 0.0.
        """
        if prev_nt is None and next_nt is None:
            return 0.0
        if prev_nt is None:
            return float(next_nt)
        if next_nt is None:
            return float(prev_nt)
        return (prev_nt + next_nt) / 2.0
class PairAnalyzer(NamedObject):
    """Run cointegration tests on every instrument pair and rank the pairs.

    Both price series of a pair are z-scored, then three tests are applied:
    - Engle-Granger (``coint``);
    - ADF on the z-score spread ``z_a - z_b``;
    - Johansen's trace test.

    Pairs are ranked per test by p-value, then ordered by the sum of the three
    ranks (lower is better).
    """

    price_field_: str
    interval_sec_: int

    def __init__(self, price_field: str, interval_sec: int) -> None:
        self.price_field_ = price_field
        self.interval_sec_ = interval_sec

    def analyze(
        self, series: Dict[ExchangeInstrument, pd.DataFrame]
    ) -> Dict[str, PairStats]:
        """Compute and rank statistics for all unordered instrument pairs.

        Args:
            series: one frame per instrument, with at least ``tstamp`` and
                ``price`` columns. Pairs are aligned on identical ``tstamp``
                values.

        Returns:
            A ``pair_name -> PairStats`` dict ordered by ascending composite rank.
            Pairs are omitted when they have fewer than two aligned complete
            observations or a constant price series.
        """
        instruments = list(series.keys())
        results: Dict[str, PairStats] = {}
        for i, first in enumerate(instruments):
            for second in instruments[i + 1 :]:
                inst_a, inst_b, pair_name = self._normalized_pair(first, second)
                df_a = series[inst_a][["tstamp", "price"]].rename(
                    columns={"price": "price_a"}
                )
                df_b = series[inst_b][["tstamp", "price"]].rename(
                    columns={"price": "price_b"}
                )
                # The inner join keeps only timestamps present for both instruments.
                merged = pd.merge(df_a, df_b, on="tstamp", how="inner").sort_values(
                    "tstamp"
                )
                stats = self._compute_stats(inst_a, inst_b, pair_name, merged)
                if stats is not None:
                    results[pair_name] = stats
        return self._rank(results)

    def _compute_stats(
        self,
        inst_a: ExchangeInstrument,
        inst_b: ExchangeInstrument,
        pair_name: str,
        merged: pd.DataFrame,
    ) -> Optional[PairStats]:
        """Run the three tests on an aligned ``price_a``/``price_b`` frame.

        Returns ``None`` when fewer than two complete rows remain or either series
        is constant. Otherwise returns a ``PairStats``. In it, a statistic is
        ``None`` when its test raised or produced a non-finite value; such
        failures are logged as warnings.
        """
        # A missing price on either side would make every test raise.
        merged = merged.dropna(subset=["price_a", "price_b"])
        if len(merged) < 2:
            return None
        px_a = merged["price_a"].astype(float)
        px_b = merged["price_b"].astype(float)
        std_a = float(px_a.std())
        std_b = float(px_b.std())
        # Constant (std == 0) or undefined (NaN std) series cannot be z-scored.
        if not (std_a > 0 and std_b > 0):
            return None
        z_a = (px_a - float(px_a.mean())) / std_a
        z_b = (px_b - float(px_b.mean())) / std_b
        pair_desc = f"{inst_a.details_short()}/{inst_b.details_short()}"

        p_eg: Optional[float]
        p_adf: Optional[float]
        p_j: Optional[float]
        trace_stat: Optional[float]
        try:
            p_eg = self._finite_or_none(coint(z_a, z_b)[1])
        except Exception as exc:
            Log.warning(f"{self.fname()}: EG failed for {pair_desc}: {exc}")
            p_eg = None
        try:
            spread = z_a - z_b
            p_adf = self._finite_or_none(
                adfuller(spread, maxlag=1, regression="c")[1]
            )
        except Exception as exc:
            Log.warning(f"{self.fname()}: ADF failed for {pair_desc}: {exc}")
            p_adf = None
        try:
            data = np.column_stack([z_a, z_b])
            res = coint_johansen(data, det_order=0, k_ar_diff=1)
            trace_stat = self._finite_or_none(res.lr1[0])
            # Row 0 tests r = 0. Its columns are the 90% / 95% / 99% critical
            # values. Johansen gives no p-value, so it is bucketed by the highest
            # confidence level that is exceeded.
            cv10, cv5, cv1 = res.cvt[0]
            if trace_stat is None:
                p_j = None
            elif trace_stat > cv1:
                p_j = 0.01
            elif trace_stat > cv5:
                p_j = 0.05
            elif trace_stat > cv10:
                p_j = 0.10
            else:
                p_j = 1.0
        except Exception as exc:
            Log.warning(f"{self.fname()}: Johansen failed for {pair_desc}: {exc}")
            p_j = None
            trace_stat = None
        return PairStats(
            pair_name_=pair_name,
            instrument_a_=inst_a,
            instrument_b_=inst_b,
            pvalue_eg_=p_eg,
            pvalue_adf_=p_adf,
            pvalue_j_=p_j,
            trace_stat_j_=trace_stat,
        )

    @staticmethod
    def _finite_or_none(value: Any) -> Optional[float]:
        """Convert ``value`` to float, mapping NaN/inf to ``None``.

        A NaN p-value would otherwise get rank 1 (best) in ``_assign_ranks``,
        because nothing compares ``<`` NaN. ``json.dumps`` would also emit it as a
        non-standard ``NaN`` token.
        """
        result = float(value)
        return result if np.isfinite(result) else None

    def _rank(self, results: Dict[str, PairStats]) -> Dict[str, PairStats]:
        """Assign per-test and composite ranks; return pairs ordered by composite rank."""
        ranked = list(results.values())
        self._assign_ranks(ranked, key=lambda r: r.pvalue_eg_, attr="rank_eg_")
        self._assign_ranks(ranked, key=lambda r: r.pvalue_adf_, attr="rank_adf_")
        self._assign_ranks(ranked, key=lambda r: r.pvalue_j_, attr="rank_j_")
        for res in ranked:
            res.composite_rank_ = res.rank_eg_ + res.rank_adf_ + res.rank_j_
        # Stable sort: ties keep their original pair iteration order.
        ranked.sort(key=lambda r: r.composite_rank_)
        return {res.pair_name_: res for res in ranked}

    @staticmethod
    def _normalized_pair(
        inst_a: ExchangeInstrument, inst_b: ExchangeInstrument
    ) -> Tuple[ExchangeInstrument, ExchangeInstrument, str]:
        """Order two instruments by label and build the ``"A<->B"`` pair name.

        NOTE(review): the label uses only ``instrument_id``. The same id on two
        exchange accounts would give colliding pair names, and ``analyze`` would
        then overwrite one result. Confirm whether that configuration is possible.
        """
        inst_a_id = PairAnalyzer._pair_label(inst_a.instrument_id())
        inst_b_id = PairAnalyzer._pair_label(inst_b.instrument_id())
        if inst_a_id <= inst_b_id:
            return inst_a, inst_b, f"{inst_a_id}<->{inst_b_id}"
        return inst_b, inst_a, f"{inst_b_id}<->{inst_a_id}"

    @staticmethod
    def _pair_label(instrument_id: str) -> str:
        """Strip a leading ``"PAIR-"`` prefix from ``instrument_id``."""
        if instrument_id.startswith("PAIR-"):
            return instrument_id[len("PAIR-") :]
        return instrument_id

    @staticmethod
    def _assign_ranks(results: List[PairStats], key, attr: str) -> None:
        """Set ``attr`` on each result to its competition rank by ``key`` (ascending).

        Tied values share a rank (1, 1, 3, ...). Results whose key is ``None`` all
        get ``number_of_valued_results + 1``, i.e. they rank after every value.
        """
        values = [key(r) for r in results]
        sorted_vals = sorted([v for v in values if v is not None])
        for res in results:
            val = key(res)
            if val is None:
                setattr(res, attr, len(sorted_vals) + 1)
                continue
            rank = 1 + sum(1 for v in sorted_vals if v < val)
            setattr(res, attr, rank)
class PairSelectionEngine(NamedObject):
    """Fetch bars, quality-check them and rank instrument pairs, once per cycle.

    The results of the latest ``run_once`` are cached for the REST handlers.
    """

    # Price fields understood by ``_extract_price``; anything else falls back to close.
    PRICE_FIELDS: Tuple[str, ...] = ("open", "high", "low", "close", "vwap")

    config_: Config
    instruments_: List[ExchangeInstrument]
    price_field_: str
    fetcher_: DataFetcher
    quality_: QualityChecker
    analyzer_: PairAnalyzer
    interval_sec_: int
    history_depth_sec_: int
    data_quality_cache_: List[InstrumentQuality]
    pair_results_cache_: Dict[str, PairStats]

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
        price_field: str,
    ) -> None:
        """
        Reads these config keys:
        - ``interval_sec`` (> 0);
        - ``history_depth_hours`` (> 0, fractional values allowed);
        - ``cvtt_base_url`` (required).
        """
        self.config_ = config
        self.instruments_ = instruments
        self.price_field_ = price_field
        interval_sec = int(config.get_value("interval_sec", 0))
        # Parse hours as float so that e.g. 0.5 works. Calling int() on the raw
        # value rejects "0.5" and truncates 0.5 to 0.
        history_depth_sec = int(
            float(config.get_value("history_depth_hours", 0)) * SecPerHour
        )
        base_url = config.get_value("cvtt_base_url", None)
        assert interval_sec > 0, "interval_sec must be > 0"
        assert history_depth_sec > 0, "history_depth_sec must be > 0"
        assert base_url, "cvtt_base_url must be set"
        if price_field not in self.PRICE_FIELDS:
            # Report the misconfiguration once rather than silently using close on every bar.
            Log.warning(
                f"{self.fname()}: unsupported price field '{price_field}', using 'close'"
            )
        self.fetcher_ = DataFetcher(
            base_url=base_url,
            interval_sec=interval_sec,
            history_depth_sec=history_depth_sec,
        )
        self.quality_ = QualityChecker(interval_sec=interval_sec)
        self.analyzer_ = PairAnalyzer(
            price_field=price_field, interval_sec=interval_sec
        )
        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec
        self.data_quality_cache_ = []
        self.pair_results_cache_ = {}

    async def run_once(self) -> None:
        """Run one selection cycle and replace both caches with its results.

        Only instruments whose quality status is "PASS" take part in pair analysis.

        NOTE(review): fetching and the statsmodels tests run synchronously inside
        this coroutine. The event loop, including the REST handlers, is therefore
        blocked for the whole cycle. Consider offloading the work to an executor.
        """
        quality_results: List[InstrumentQuality] = []
        price_series: Dict[ExchangeInstrument, pd.DataFrame] = {}
        for inst in self.instruments_:
            exch_acct = inst.user_data_.get("exch_acct") or inst.exchange_id_
            aggr = self.fetcher_.fetch(exch_acct=exch_acct, inst=inst)
            q = self.quality_.evaluate(inst, aggr)
            quality_results.append(q)
            if q.status_ != "PASS":
                continue
            df = self._to_dataframe(aggr, inst)
            if len(df) > 0:
                price_series[inst] = df
        self.data_quality_cache_ = quality_results
        self.pair_results_cache_ = self.analyzer_.analyze(price_series)

    def _to_dataframe(
        self, aggr: List[MdTradesAggregate], inst: ExchangeInstrument
    ) -> pd.DataFrame:
        """Build a time-sorted frame with ``tstamp`` (UTC), ``price`` and ``num_trades`` columns."""
        rows: List[Dict[str, Any]] = [
            {
                "tstamp": pd.to_datetime(item.aggr_time_ns_, unit="ns", utc=True),
                "price": self._extract_price(item, inst),
                "num_trades": item.num_trades_,
            }
            for item in aggr
        ]
        if not rows:
            # pd.DataFrame([]) has no columns, so sort_values("tstamp") would raise KeyError.
            return pd.DataFrame(columns=["tstamp", "price", "num_trades"])
        return pd.DataFrame(rows).sort_values("tstamp").reset_index(drop=True)

    def _extract_price(
        self, aggr: MdTradesAggregate, inst: ExchangeInstrument
    ) -> float:
        """Read the configured price field from ``aggr`` and convert it via ``inst.get_price``."""
        # MdTradesAggregate inherits hist bar with fields open_, high_, low_, close_, vwap_
        field = self.price_field_ if self.price_field_ in self.PRICE_FIELDS else "close"
        raw = getattr(aggr, f"{field}_")
        return inst.get_price(raw)

    def sleep_seconds_until_next_cycle(self) -> float:
        """Return the seconds until the next multiple of ``interval_sec_`` (epoch-aligned)."""
        now_ns = current_nanoseconds()
        interval_ns = self.interval_sec_ * NanoPerSec
        next_boundary = (now_ns // interval_ns + 1) * interval_ns
        return max(0.0, (next_boundary - now_ns) / NanoPerSec)

    def quality_dicts(self) -> List[Dict[str, Any]]:
        """Return the cached quality results as JSON-friendly dicts."""
        return [
            {
                "instrument": q.instrument_.instrument_id(),
                "record_count": q.record_count_,
                "latest_tstamp": (
                    q.latest_tstamp_.isoformat()
                    if q.latest_tstamp_ is not None
                    else None
                ),
                "status": q.status_,
                "reason": q.reason_,
            }
            for q in self.data_quality_cache_
        ]

    def pair_dicts(self) -> Dict[str, Dict[str, Any]]:
        """Return the cached pair results as ``pair_name -> dict``, in rank order."""
        return {
            pair_name: stats.as_dict()
            for pair_name, stats in self.pair_results_cache_.items()
        }
class PairSelector(NamedObject):
    """Application wiring: config loading, the periodic selection loop and the REST endpoints."""

    instruments_: List[ExchangeInstrument]
    engine_: PairSelectionEngine
    rest_service_: RestService

    def __init__(self) -> None:
        App.instance().add_cmdline_arg("--oneshot", action="store_true", default=False)
        App.instance().add_call(App.Stage.Config, self._on_config())
        App.instance().add_call(App.Stage.Run, self.run())

    async def _on_config(self) -> None:
        """Load instruments, build the engine and register the REST handlers."""
        cfg = CvttAppConfig.instance()
        self.instruments_ = self._load_instruments(cfg)
        price_field = cfg.get_value("model/stat_model_price", "close")
        self.engine_ = PairSelectionEngine(
            config=cfg,
            instruments=self.instruments_,
            price_field=price_field,
        )
        self.rest_service_ = RestService(config_key="/api/REST")
        self.rest_service_.add_handler("GET", "/data_quality", self._on_data_quality)
        self.rest_service_.add_handler(
            "GET", "/pair_selection", self._on_pair_selection
        )

    def _load_instruments(self, cfg: CvttAppConfig) -> List[ExchangeInstrument]:
        """Resolve the configured ``instruments`` list (at least two entries).

        Each entry is either an ``"<exch_acct>:<instrument_id>"`` string or a
        dict with ``exch_acct`` and ``instrument_id`` keys. The account is stored
        in ``user_data_["exch_acct"]`` for later fetches.

        Raises:
            ValueError: on a malformed entry.
            AssertionError: if fewer than two entries are configured or an
                instrument cannot be resolved.
        """
        instruments_cfg = cfg.get_value("instruments", [])
        instruments: List[ExchangeInstrument] = []
        assert len(instruments_cfg) >= 2, "at least two instruments required"
        for item in instruments_cfg:
            if isinstance(item, str):
                exch_acct, sep, instrument_id = item.partition(":")
                # Tolerate whitespace such as "ACCT: INST" in hand-written configs.
                exch_acct, instrument_id = exch_acct.strip(), instrument_id.strip()
                if not sep or not exch_acct or not instrument_id:
                    raise ValueError(f"invalid instrument format: {item}")
            elif isinstance(item, dict):
                exch_acct = item.get("exch_acct", "")
                instrument_id = item.get("instrument_id", "")
                if not exch_acct or not instrument_id:
                    raise ValueError(f"invalid instrument config: {item}")
            else:
                raise ValueError(f"unsupported instrument entry: {item}")
            exch_inst = ExchangeAccounts.instance().get_exchange_instrument(
                exch_acct=exch_acct, instrument_id=instrument_id
            )
            assert (
                exch_inst is not None
            ), f"no ExchangeInstrument for {exch_acct}:{instrument_id}"
            exch_inst.user_data_["exch_acct"] = exch_acct
            instruments.append(exch_inst)
        return instruments

    async def run(self) -> None:
        """Run selection cycles aligned to interval boundaries; with --oneshot, run one cycle.

        A failing cycle is logged and the loop waits for the next boundary, so a
        transient error does not stop the service permanently. In oneshot mode the
        error is re-raised so that it surfaces to the caller. Cancellation is
        never swallowed.
        """
        oneshot = App.instance().get_argument("oneshot", False)
        while True:
            try:
                await self.engine_.run_once()
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                Log.error(f"{self.fname()}: selection cycle failed: {exc!r}")
                if oneshot:
                    raise
            if oneshot:
                break
            sleep_for = self.engine_.sleep_seconds_until_next_cycle()
            await asyncio.sleep(sleep_for)

    async def _on_data_quality(self, request: web.Request) -> web.Response:
        """Serve cached quality results as HTML (default) or JSON (``?format=json``)."""
        fmt = request.query.get("format", "html").lower()
        quality = self.engine_.quality_dicts()
        if fmt == "json":
            return web.json_response(quality)
        return web.Response(
            text=HtmlRenderer.render_data_quality(quality), content_type="text/html"
        )

    async def _on_pair_selection(self, request: web.Request) -> web.Response:
        """Serve cached pair rankings as HTML (default) or JSON (``?format=json``)."""
        fmt = request.query.get("format", "html").lower()
        pairs = self.engine_.pair_dicts()
        if fmt == "json":
            return web.json_response(pairs)
        return web.Response(
            text=HtmlRenderer.render_pairs(pairs), content_type="text/html"
        )
if __name__ == "__main__":
    # Construction order matters. PairSelector's constructor calls App.instance()
    # to register its --oneshot flag and its Config/Run stage callbacks. App (and,
    # presumably, CvttAppConfig) must therefore be created before it, which
    # presumably registers the instances that .instance() returns. TODO confirm.
    App()
    CvttAppConfig()
    PairSelector()
    # Presumably executes the registered stages (Config, then Run); see cvttpy_tools.app.
    App.instance().run()