pairs_trading/lib/pt_strategy/pt_market_data.py
Oleg Sheynin b196863a34 progress
2026-01-11 13:33:58 +00:00

224 lines
8.7 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import pandas as pd
# ---
from cvttpy_tools.base import NamedObject
from cvttpy_tools.config import Config
from cvttpy_tools.settings.cvtt_types import JsonDictT
# ---
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
from cvttpy_trading.trading.instrument import ExchangeInstrument
# ---
from pairs_trading.lib.tools.data_loader import load_market_data
class PtMarketData(NamedObject, ABC):
    """Market data for a pair of instruments, pivoted into one wide frame.

    Raw data in ``origin_mkt_data_df_`` is in long format, one row per
    ``(tstamp, symbol)``. :meth:`set_market_data` pivots it into
    ``market_data_df_``, which has one row per timestamp and one column per
    ``<price column>_<symbol>``.

    Subclasses decide three things:

    * which raw columns are used (:meth:`md_columns`);
    * how each symbol's slice is renamed (:meth:`rename_columns`);
    * which renamed columns are kept (:meth:`tranform_df_target_colnames`).
    """

    config_: Config
    origin_mkt_data_df_: pd.DataFrame
    market_data_df_: pd.DataFrame
    stat_model_price_: str
    instruments_: List[ExchangeInstrument]
    symbol_a_: str
    symbol_b_: str

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        """Store config and instruments and derive the two leg symbols.

        Args:
            config: Configuration; ``stat_model_price`` is read here.
            instruments: The pair's instruments. The first two are used as
                legs A and B.

        Raises:
            AssertionError: If fewer than two instruments are given.
        """
        self.config_ = config
        self.origin_mkt_data_df_ = pd.DataFrame()
        self.market_data_df_ = pd.DataFrame()
        self.stat_model_price_ = self.config_.get_value("stat_model_price")
        self.instruments_ = instruments
        # Both legs are indexed below, so two instruments are required. A
        # "> 0" check would let a single instrument through to an IndexError.
        assert len(self.instruments_) >= 2, (
            "Pair trading needs at least two instruments, "
            f"got {len(self.instruments_)}"
        )
        # The symbol is whatever follows the first "-" in instrument_id().
        # NOTE(review): presumably ids look like "<exchange>-<symbol>"; confirm.
        self.symbol_a_ = self.instruments_[0].instrument_id().split("-", 1)[1]
        self.symbol_b_ = self.instruments_[1].instrument_id().split("-", 1)[1]

    @abstractmethod
    def md_columns(self) -> List[str]:
        """Raw columns to take from the long-format data (must include tstamp/symbol)."""
        ...

    @abstractmethod
    def rename_columns(self, symbol_df: pd.DataFrame) -> pd.DataFrame:
        """Turn one symbol's rows into ``tstamp`` plus per-symbol price columns."""
        ...

    @abstractmethod
    def tranform_df_target_colnames(self) -> List[str]:
        """Wide-frame columns (besides ``tstamp``) kept in ``market_data_df_``."""
        ...

    def set_market_data(self) -> None:
        """Rebuild ``market_data_df_`` from ``origin_mkt_data_df_``.

        The steps are:

        1. Keep ``tstamp`` plus :meth:`tranform_df_target_colnames`.
        2. Drop rows with any missing value.
        3. Convert ``tstamp`` to datetime.
        4. Sort by ``tstamp``.

        The resulting index is 0..n-1 in timestamp order, so positional and
        label access agree.
        """
        wide_df = self._transform_dataframe(self.origin_mkt_data_df_)
        selected = pd.DataFrame(
            wide_df[["tstamp"] + self.tranform_df_target_colnames()]
        )
        selected = selected.dropna().reset_index(drop=True)
        selected["tstamp"] = pd.to_datetime(selected["tstamp"])
        self.market_data_df_ = selected.sort_values(
            "tstamp", kind="stable"
        ).reset_index(drop=True)

    def colnames(self) -> List[str]:
        """Stat-model price columns of both legs, e.g. ``close_<A>``, ``close_<B>``."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a_}",
            f"{self.stat_model_price_}_{self.symbol_b_}",
        ]

    def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Pivot long-format ``df`` into one row per unique timestamp.

        For every symbol in ``df["symbol"]``, :meth:`rename_columns` produces
        ``tstamp`` plus that symbol's renamed price columns. These are
        left-joined onto the unique timestamps.

        Missing values are deliberately kept. Callers select only the columns
        they need and drop NaNs afterwards, so a gap in an unrelated symbol
        cannot remove otherwise valid rows.
        """
        df_selected: pd.DataFrame = pd.DataFrame(df[self.md_columns()])
        result_df = (
            pd.DataFrame(df_selected["tstamp"]).drop_duplicates().reset_index(drop=True)
        )
        for symbol in df_selected["symbol"].unique():
            df_symbol = df_selected[df_selected["symbol"] == symbol].reset_index(
                drop=True
            )
            # The subclasses in this file name columns like "close_COIN".
            temp_df: pd.DataFrame = self.rename_columns(df_symbol)
            result_df = pd.merge(result_df, temp_df, on="tstamp", how="left")
            result_df = result_df.reset_index(drop=True)
        # No dropna() here (see docstring): set_market_data drops NaNs after
        # restricting to the relevant columns.
        return result_df
class ResearchMarketData(PtMarketData):
    """Historical market data for research/back-testing, replayed row by row.

    :meth:`load` reads every file listed under the ``datafiles`` config key
    and pivots the data into ``market_data_df_``. When the config contains an
    ``execution_price`` entry (with ``column`` and ``shift`` keys), it also
    adds ``exec_price_<symbol>`` columns. :meth:`has_next` and
    :meth:`get_next` then iterate over ``market_data_df_`` row by row.
    """

    current_index_: int
    is_execution_price_: bool
    execution_price_column_: Optional[str]
    execution_price_shift_: int

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        """Initialise the replay cursor and read the optional execution-price config."""
        super().__init__(config, instruments)
        self.current_index_ = 0
        self.is_execution_price_ = self.config_.key_exists("execution_price")
        if self.is_execution_price_:
            execution_price_cfg = self.config_.get_value("execution_price")
            self.execution_price_column_ = execution_price_cfg["column"]
            self.execution_price_shift_ = execution_price_cfg["shift"]
        else:
            self.execution_price_column_ = None
            self.execution_price_shift_ = 0

    def has_next(self) -> bool:
        """Return True while :meth:`get_next` still has rows to return."""
        return self.current_index_ < len(self.market_data_df_)

    def get_next(self) -> pd.Series:
        """Return the current row of ``market_data_df_`` and advance the cursor.

        Raises:
            IndexError: If called after :meth:`has_next` returned False.
        """
        result = self.market_data_df_.iloc[self.current_index_]
        self.current_index_ += 1
        return result

    def load(self) -> None:
        """Load all configured data files and build ``market_data_df_``.

        Raises:
            AssertionError: If the ``datafiles`` config entry is missing or empty.
        """
        datafiles: List[str] = self.config_.get_value("datafiles", [])
        assert len(datafiles) > 0, "No datafiles found in config"

        # The loader arguments do not depend on the file, so resolve them once.
        instrument_type = self.instruments_[0].user_data_.get(
            "instrument_type", "?instrument_type?"
        )
        db_table_name = self.config_.get_value("market_data_loading")[
            instrument_type
        ]["db_table_name"]
        trading_hours = self.config_.get_value("trading_hours")
        # The execution-price row shift is passed to the loader as extra
        # minutes. NOTE(review): this presumably assumes one row per minute,
        # so that shifted prices exist near the window end; confirm.
        extra_minutes: int = self.execution_price_shift_

        # Collect every frame and concatenate once. Re-concatenating per file
        # copies the data repeatedly. The initial placeholder frame is skipped
        # when empty.
        frames: List[pd.DataFrame] = (
            [] if self.origin_mkt_data_df_.empty else [self.origin_mkt_data_df_]
        )
        for datafile in datafiles:
            frames.append(
                load_market_data(
                    datafile=datafile,
                    instruments=self.instruments_,
                    db_table_name=db_table_name,
                    trading_hours=trading_hours,
                    extra_minutes=extra_minutes,
                )
            )
        self.origin_mkt_data_df_ = (
            pd.concat(frames)
            .sort_values(by="tstamp")
            .dropna()
            .reset_index(drop=True)
        )
        self.set_market_data()
        self._set_execution_price_data()

    def _set_execution_price_data(self) -> None:
        """Add ``exec_price_<symbol>`` columns for both legs, if configured.

        Each column is ``<execution_price_column_>_<symbol>`` shifted by
        ``-execution_price_shift_`` rows. Row i therefore holds the value from
        row ``i + shift``. Rows left without a value by the shift are dropped.
        """
        # is_execution_price_ mirrors config key_exists("execution_price"), so
        # the config values cached in __init__ are always set past this guard.
        if not self.is_execution_price_:
            return
        for symbol in (self.symbol_a_, self.symbol_b_):
            source_column = f"{self.execution_price_column_}_{symbol}"
            self.market_data_df_[f"exec_price_{symbol}"] = self.market_data_df_[
                source_column
            ].shift(-self.execution_price_shift_)
        self.market_data_df_ = self.market_data_df_.dropna().reset_index(drop=True)

    def _price_source_columns(self) -> List[str]:
        """Distinct raw price columns in use: stat-model price, then execution price."""
        columns: List[str] = [self.stat_model_price_]
        if self.is_execution_price_ and self.execution_price_column_ is not None:
            columns.append(self.execution_price_column_)
        # The two may name the same column, e.g. both "close".
        return list(dict.fromkeys(columns))

    def md_columns(self) -> List[str]:
        """Raw columns to read: ``tstamp``, ``symbol`` and the distinct price columns."""
        return ["tstamp", "symbol"] + self._price_source_columns()

    def rename_columns(self, selected_symbol_df: pd.DataFrame) -> pd.DataFrame:
        """Return ``tstamp`` plus this symbol's price columns suffixed ``_<symbol>``.

        ``selected_symbol_df`` must be non-empty and hold a single symbol;
        only the first row's ``symbol`` is read.
        """
        symbol = selected_symbol_df.iloc[0]["symbol"]
        renamed: Dict[str, pd.Series] = {"tstamp": selected_symbol_df["tstamp"]}
        for price_column in self._price_source_columns():
            renamed[f"{price_column}_{symbol}"] = selected_symbol_df[price_column]
        return pd.DataFrame(renamed)

    def tranform_df_target_colnames(self) -> List[str]:
        """Stat-model price columns plus raw execution-price columns, deduplicated."""
        # Duplicate labels arise when the execution-price column equals the
        # stat-model price. They would make later single-column lookups return
        # DataFrames instead of Series.
        return list(dict.fromkeys(self.colnames() + self.orig_exec_prices_colnames()))

    def orig_exec_prices_colnames(self) -> List[str]:
        """Unshifted execution-price columns of both legs (empty if not configured)."""
        if not self.is_execution_price_:
            return []
        return [
            f"{self.execution_price_column_}_{self.symbol_a_}",
            f"{self.execution_price_column_}_{self.symbol_b_}",
        ]
class LiveMarketData(PtMarketData):
    """Pair market data for live use.

    Only the stat-model price is tracked per symbol; there are no
    execution-price columns.
    """

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        """Delegate to the base constructor; there is no extra live-specific state."""
        super().__init__(config, instruments)

    def md_columns(self) -> List[str]:
        """Raw columns needed: timestamp, symbol and the stat-model price."""
        return ["tstamp", "symbol", self.stat_model_price_]

    def rename_columns(self, selected_symbol_df: pd.DataFrame) -> pd.DataFrame:
        """Return ``tstamp`` and the stat-model price renamed to ``<price>_<symbol>``."""
        price_col = self.stat_model_price_
        leg_symbol = selected_symbol_df.iloc[0]["symbol"]
        columns = {
            "tstamp": selected_symbol_df["tstamp"],
            f"{price_col}_{leg_symbol}": selected_symbol_df[price_col],
        }
        return pd.DataFrame(columns)

    def tranform_df_target_colnames(self) -> List[str]:
        """Only the two stat-model price columns are kept for live trading."""
        return self.colnames()