pairs_trading/apps/pair_selector/pair_selector_engine.py.md

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import adfuller, coint
from statsmodels.tsa.vector_ar.vecm import coint_johansen
# ---
from cvttpy_tools.base import NamedObject
from cvttpy_tools.config import Config
from cvttpy_tools.logger import Log
from cvttpy_tools.timeutils import NanoPerSec, SecPerHour, current_nanoseconds
from cvttpy_tools.web.rest_client import RESTSender
# ---
from cvttpy_trading.trading.instrument import ExchangeInstrument
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate, MdSummary


@dataclass
class InstrumentQuality(NamedObject):
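    """Result of a data-quality check for a single instrument."""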
    instrument_: ExchangeInstrument
    record_count_: int
    latest_tstamp_: Optional[pd.Timestamp]
    status_: str
    reason_: str


@dataclass
class PairStats(NamedObject):
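    """Cointegration statistics and ranks for one instrument pair."""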
    instrument_a_: ExchangeInstrument
    instrument_b_: ExchangeInstrument
    pvalue_eg_: Optional[float]
    pvalue_adf_: Optional[float]
    pvalue_j_: Optional[float]
    trace_stat_j_: Optional[float]
    rank_eg_: int = 0
    rank_adf_: int = 0
    rank_j_: int = 0
    composite_rank_: int = 0

    def as_dict(self) -> Dict[str, Any]:
        return {
            "instrument_a": self.instrument_a_.instrument_id(),
            "instrument_b": self.instrument_b_.instrument_id(),
            "pvalue_eg": self.pvalue_eg_,
            "pvalue_adf": self.pvalue_adf_,
            "pvalue_j": self.pvalue_j_,
            "trace_stat_j": self.trace_stat_j_,
            "rank_eg": self.rank_eg_,
            "rank_adf": self.rank_adf_,
            "rank_j": self.rank_j_,
            "composite_rank": self.composite_rank_,
        }


class DataFetcher(NamedObject):
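    """Pulls trade aggregates for one instrument over the configured history window."""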
    sender_: RESTSender
    interval_sec_: int
    history_depth_sec_: int

    def __init__(
        self,
        base_url: str,
        interval_sec: int,
        history_depth_sec: int,
    ) -> None:
        self.sender_ = RESTSender(base_url=base_url)
        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec

    def fetch(self, exch_acct: str, inst: ExchangeInstrument) -> List[MdTradesAggregate]:
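        """POST to the md_summary endpoint and convert each MdSummary into an
        MdTradesAggregate; an HTTP error is logged and yields an empty list."""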
        rqst_data = {
            "exch_acct": exch_acct,
            "instrument_id": inst.instrument_id(),
            "interval_sec": self.interval_sec_,
            "history_depth_sec": self.history_depth_sec_,
        }
        response = self.sender_.send_post(endpoint="md_summary", post_body=rqst_data)
        if response.status_code not in (200, 201):
            Log.error(
                f"{self.fname()}: error {response.status_code} for {inst.details_short()}: {response.text}")
            return []
        mdsums: List[MdSummary] = MdSummary.from_REST_response(response=response)
        return [
            mdsum.create_md_trades_aggregate(
                exch_acct=exch_acct, exch_inst=inst, interval_sec=self.interval_sec_
            )
            for mdsum in mdsums
        ]


class QualityChecker(NamedObject):
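    """Gates instruments on record count, recency, and gap size."""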
    interval_sec_: int

    def __init__(self, interval_sec: int) -> None:
        self.interval_sec_ = interval_sec

    def evaluate(self, inst: ExchangeInstrument, aggr: List[MdTradesAggregate]) -> InstrumentQuality:
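        """Classify an instrument's data as PASS or FAIL: FAIL on no records,
        on staleness beyond two intervals, or on an oversized gap."""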
        if len(aggr) == 0:
            return InstrumentQuality(
                instrument_=inst,
                record_count_=0,
                latest_tstamp_=None,
                status_="FAIL",
                reason_="no records",
            )

        aggr_sorted = sorted(aggr, key=lambda a: a.aggr_time_ns_)

        latest_ts = pd.to_datetime(aggr_sorted[-1].aggr_time_ns_, unit="ns", utc=True)
        now_ts = pd.Timestamp.now(tz="UTC")  # tz-aware; Timestamp.utcnow() is deprecated
        recency_cutoff = now_ts - pd.Timedelta(seconds=2 * self.interval_sec_)
        if latest_ts <= recency_cutoff:
            return InstrumentQuality(
                instrument_=inst,
                record_count_=len(aggr_sorted),
                latest_tstamp_=latest_ts,
                status_="FAIL",
                reason_=f"stale: latest {latest_ts} <= cutoff {recency_cutoff}",
            )

        gaps_ok, reason = self._check_gaps(aggr_sorted)
        status = "PASS" if gaps_ok else "FAIL"
        return InstrumentQuality(
            instrument_=inst,
            record_count_=len(aggr_sorted),
            latest_tstamp_=latest_ts,
            status_=status,
            reason_=reason,
        )

    def _check_gaps(self, aggr: List[MdTradesAggregate]) -> Tuple[bool, str]:
        NUM_TRADES_THRESHOLD = 50
        if len(aggr) < 2:
            return True, "ok"

        interval_ns = self.interval_sec_ * NanoPerSec
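        # A delta of exactly one interval means consecutive bars; e.g. with a
        # 60s interval, a 180s delta implies 180 // 60 - 1 = 2 missing bars.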
        for idx in range(1, len(aggr)):
            prev = aggr[idx - 1]
            curr = aggr[idx]
            delta = curr.aggr_time_ns_ - prev.aggr_time_ns_
            missing_intervals = int(delta // interval_ns) - 1
            if missing_intervals <= 0:
                continue
            
            prev_nt = prev.num_trades_
            next_nt = curr.num_trades_
            estimate = self._approximate_num_trades(prev_nt, next_nt)
            if estimate > NUM_TRADES_THRESHOLD:
                return False, (
                    f"gap of {missing_intervals} interval(s), est num_trades={estimate} > {NUM_TRADES_THRESHOLD}"
                )
        return True, "ok"

    @staticmethod
    def _approximate_num_trades(prev_nt: Optional[int], next_nt: Optional[int]) -> float:
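        """Midpoint estimate of the per-bar trade count across a gap,
        e.g. (40, 80) -> 60.0; a missing endpoint falls back to the other."""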
        if prev_nt is None and next_nt is None:
            return 0.0
        if prev_nt is None:
            return float(next_nt)
        if next_nt is None:
            return float(prev_nt)
        return (prev_nt + next_nt) / 2.0


class PairAnalyzer(NamedObject):
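    """Computes pairwise cointegration statistics and composite ranks."""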
    price_field_: str
    interval_sec_: int

    def __init__(self, price_field: str, interval_sec: int) -> None:
        self.price_field_ = price_field
        self.interval_sec_ = interval_sec

    def analyze(self, series: Dict[ExchangeInstrument, pd.DataFrame]) -> List[PairStats]:
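        """Build every unordered instrument pair, inner-join the two series on
        tstamp, compute cointegration statistics, and rank the results."""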
        instruments = list(series.keys())
        results: List[PairStats] = []
        for i in range(len(instruments)):
            for j in range(i + 1, len(instruments)):
                inst_a = instruments[i]
                inst_b = instruments[j]
                df_a = series[inst_a][["tstamp", "price"]].rename(
                    columns={"price": "price_a"}
                )
                df_b = series[inst_b][["tstamp", "price"]].rename(
                    columns={"price": "price_b"}
                )
                merged = pd.merge(df_a, df_b, on="tstamp", how="inner").sort_values(
                    "tstamp"
                )
                stats = self._compute_stats(inst_a, inst_b, merged)
                if stats:
                    results.append(stats)
        self._rank(results)
        return results

    def _compute_stats(
        self,
        inst_a: ExchangeInstrument,
        inst_b: ExchangeInstrument,
        merged: pd.DataFrame,
    ) -> Optional[PairStats]:
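        """Z-score both price series, then run the Engle-Granger, ADF-on-spread,
        and Johansen tests; a failing test contributes None for its statistic
        rather than discarding the pair."""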
        if len(merged) < 2:
            return None
        px_a = merged["price_a"].astype(float)
        px_b = merged["price_b"].astype(float)

        std_a = float(px_a.std())
        std_b = float(px_b.std())
        if std_a == 0 or std_b == 0:
            return None

        z_a = (px_a - float(px_a.mean())) / std_a
        z_b = (px_b - float(px_b.mean())) / std_b

        p_eg: Optional[float]
        p_adf: Optional[float]
        p_j: Optional[float]
        trace_stat: Optional[float]

        try:
            p_eg = float(coint(z_a, z_b)[1])
        except Exception as exc:
            Log.warning(f"{self.fname()}: EG failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}")
            p_eg = None

        try:
            spread = z_a - z_b
            p_adf = float(adfuller(spread, maxlag=1, regression="c")[1])
        except Exception as exc:
            Log.warning(f"{self.fname()}: ADF failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}")
            p_adf = None

        try:
            data = np.column_stack([z_a, z_b])
            res = coint_johansen(data, det_order=0, k_ar_diff=1)
            trace_stat = float(res.lr1[0])
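            # Johansen reports critical values rather than a p-value; res.cvt
            # row 0 holds the rank-0 trace critical values at the 90%, 95%,
            # and 99% confidence levels, so the statistic is bucketed below.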
            cv10, cv5, cv1 = res.cvt[0]
            if trace_stat > cv1:
                p_j = 0.01
            elif trace_stat > cv5:
                p_j = 0.05
            elif trace_stat > cv10:
                p_j = 0.10
            else:
                p_j = 1.0
        except Exception as exc:
            Log.warning(f"{self.fname()}: Johansen failed for {inst_a.details_short()}/{inst_b.details_short()}: {exc}")
            p_j = None
            trace_stat = None

        return PairStats(
            instrument_a_=inst_a,
            instrument_b_=inst_b,
            pvalue_eg_=p_eg,
            pvalue_adf_=p_adf,
            pvalue_j_=p_j,
            trace_stat_j_=trace_stat,
        )

    def _rank(self, results: List[PairStats]) -> None:
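        """Composite rank is the sum of the three per-test ranks; results are
        sorted ascending, so a lower composite rank means a stronger pair."""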
        self._assign_ranks(results, key=lambda r: r.pvalue_eg_, attr="rank_eg_")
        self._assign_ranks(results, key=lambda r: r.pvalue_adf_, attr="rank_adf_")
        self._assign_ranks(results, key=lambda r: r.pvalue_j_, attr="rank_j_")
        for res in results:
            res.composite_rank_ = res.rank_eg_ + res.rank_adf_ + res.rank_j_
        results.sort(key=lambda r: r.composite_rank_)

    @staticmethod
    def _assign_ranks(
        results: List[PairStats],
        key: Callable[[PairStats], Optional[float]],
        attr: str,
    ) -> None:
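        """Rank by ascending value, 1 being best; e.g. values [0.02, None, 0.01]
        get ranks [2, 3, 1]. Missing values rank after all present ones, and
        ties share a rank because only strictly smaller values are counted."""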
        values = [key(r) for r in results]
        sorted_vals = sorted([v for v in values if v is not None])
        for res in results:
            val = key(res)
            if val is None:
                setattr(res, attr, len(sorted_vals) + 1)
                continue
            rank = 1 + sum(1 for v in sorted_vals if v < val)
            setattr(res, attr, rank)


class PairSelectionEngine(NamedObject):
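    """Orchestrates one fetch/quality-check/analyze cycle and caches the results."""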
    config_: Config
    instruments_: List[ExchangeInstrument]
    price_field_: str
    fetcher_: DataFetcher
    quality_: QualityChecker
    analyzer_: PairAnalyzer
    interval_sec_: int
    history_depth_sec_: int
    data_quality_cache_: List[InstrumentQuality]
    pair_results_cache_: List[PairStats]

    def __init__(
        self,
        config: Config,
        instruments: List[ExchangeInstrument],
        price_field: str,
    ) -> None:
        self.config_ = config
        self.instruments_ = instruments
        self.price_field_ = price_field

        interval_sec = int(config.get_value("interval_sec", 0))
        history_depth_sec = int(config.get_value("history_depth_hours", 0)) * SecPerHour
        base_url = config.get_value("cvtt_base_url", None)
        assert interval_sec > 0, "interval_sec must be > 0"
        assert history_depth_sec > 0, "history_depth_sec must be > 0"
        assert base_url, "cvtt_base_url must be set"

        self.fetcher_ = DataFetcher(
            base_url=base_url,
            interval_sec=interval_sec,
            history_depth_sec=history_depth_sec,
        )
        self.quality_ = QualityChecker(interval_sec=interval_sec)
        self.analyzer_ = PairAnalyzer(price_field=price_field, interval_sec=interval_sec)

        self.interval_sec_ = interval_sec
        self.history_depth_sec_ = history_depth_sec

        self.data_quality_cache_ = []
        self.pair_results_cache_ = []

    async def run_once(self) -> None:
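        """Run one cycle: fetch aggregates per instrument, keep only those that
        pass the quality gate, and refresh both caches. Declared async for the
        caller's event loop, though the REST fetch itself is synchronous."""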
        quality_results: List[InstrumentQuality] = []
        price_series: Dict[ExchangeInstrument, pd.DataFrame] = {}

        for inst in self.instruments_:
            exch_acct = inst.user_data_.get("exch_acct") or inst.exchange_id_
            aggr = self.fetcher_.fetch(exch_acct=exch_acct, inst=inst)
            q = self.quality_.evaluate(inst, aggr)
            quality_results.append(q)
            if q.status_ != "PASS":
                continue
            df = self._to_dataframe(aggr, inst)
            if len(df) > 0:
                price_series[inst] = df
        self.data_quality_cache_ = quality_results
        self.pair_results_cache_ = self.analyzer_.analyze(price_series)

    def _to_dataframe(self, aggr: List[MdTradesAggregate], inst: ExchangeInstrument) -> pd.DataFrame:
        rows: List[Dict[str, Any]] = []
        for item in aggr:
            rows.append(
                {
                    "tstamp": pd.to_datetime(item.aggr_time_ns_, unit="ns", utc=True),
                    "price": self._extract_price(item, inst),
                    "num_trades": item.num_trades_,
                }
            )
        df = pd.DataFrame(rows)
        return df.sort_values("tstamp").reset_index(drop=True)

    def _extract_price(self, aggr: MdTradesAggregate, inst: ExchangeInstrument) -> float:
        price_field = self.price_field_
        # MdTradesAggregate inherits the historical-bar fields open_, high_, low_, close_, vwap_
        field_map = {
            "open": aggr.open_,
            "high": aggr.high_,
            "low": aggr.low_,
            "close": aggr.close_,
            "vwap": aggr.vwap_,
        }
        raw = field_map.get(price_field, aggr.close_)
        return inst.get_price(raw)

    def sleep_seconds_until_next_cycle(self) -> float:
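        """Seconds until the next interval boundary; e.g. with a 60s interval
        at t=125s the next boundary is 180s, so this returns 55.0."""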
        now_ns = current_nanoseconds()
        interval_ns = self.interval_sec_ * NanoPerSec
        next_boundary = (now_ns // interval_ns + 1) * interval_ns
        return max(0.0, (next_boundary - now_ns) / NanoPerSec)

    def quality_dicts(self) -> List[Dict[str, Any]]:
        res: List[Dict[str, Any]] = []
        for q in self.data_quality_cache_:
            res.append(
                {
                    "instrument": q.instrument_.instrument_id(),
                    "record_count": q.record_count_,
                    "latest_tstamp": q.latest_tstamp_.isoformat() if q.latest_tstamp_ else None,
                    "status": q.status_,
                    "reason": q.reason_,
                }
            )
        return res

    def pair_dicts(self) -> List[Dict[str, Any]]:
        return [p.as_dict() for p in self.pair_results_cache_]
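

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a driver loop might run the engine.
# Config construction and the instrument list come from outside this module,
# and `publish` is a hypothetical sink; all three are assumptions.
#
#   import asyncio
#
#   async def main(config: Config, instruments: List[ExchangeInstrument]) -> None:
#       engine = PairSelectionEngine(
#           config=config, instruments=instruments, price_field="close"
#       )
#       while True:
#           await engine.run_once()
#           publish(engine.pair_dicts())  # hypothetical sink
#           await asyncio.sleep(engine.sleep_seconds_until_next_cycle())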