224 lines
8.7 KiB
Python
224 lines
8.7 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import pandas as pd
|
|
|
|
# ---
|
|
from cvttpy_tools.base import NamedObject
|
|
from cvttpy_tools.config import Config
|
|
from cvttpy_tools.settings.cvtt_types import JsonDictT
|
|
|
|
# ---
|
|
from cvttpy_trading.trading.mkt_data.md_summary import MdTradesAggregate
|
|
from cvttpy_trading.trading.instrument import ExchangeInstrument
|
|
# ---
|
|
from pairs_trading.lib.tools.data_loader import load_market_data
|
|
|
|
|
|
class PtMarketData(NamedObject, ABC):
    """Base container for the market data of a traded pair.

    Holds the raw (long-format) market data in ``origin_mkt_data_df_`` and a
    wide, per-timestamp view in ``market_data_df_`` that has one price column
    per symbol. Subclasses decide which raw columns are used and how they are
    renamed by implementing ``md_columns``, ``rename_columns`` and
    ``tranform_df_target_colnames``.
    """

    config_: Config
    origin_mkt_data_df_: pd.DataFrame
    market_data_df_: pd.DataFrame
    stat_model_price_: str
    instruments_: List[ExchangeInstrument]
    symbol_a_: str
    symbol_b_: str

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        """Store the configuration and derive the two pair symbols.

        Args:
            config: Configuration; the ``stat_model_price`` key is read here.
            instruments: Instruments of the pair. The first two entries are
                used as leg A and leg B.

        Raises:
            AssertionError: If fewer than two instruments are supplied.
        """
        self.config_ = config
        self.origin_mkt_data_df_ = pd.DataFrame()
        self.market_data_df_ = pd.DataFrame()
        self.stat_model_price_ = self.config_.get_value("stat_model_price")

        self.instruments_ = instruments
        assert len(self.instruments_) > 0, "No instruments found in config"
        # Both legs are indexed below, so a single instrument is not enough;
        # fail with a clear message instead of a bare IndexError.
        assert len(self.instruments_) >= 2, (
            f"Two instruments are required for a pair, got {len(self.instruments_)}"
        )
        # The symbol is whatever follows the first "-" in the instrument id.
        self.symbol_a_ = self.instruments_[0].instrument_id().split("-", 1)[1]
        self.symbol_b_ = self.instruments_[1].instrument_id().split("-", 1)[1]

    @abstractmethod
    def md_columns(self) -> List[str]:
        """Return the raw columns to select from the original market data."""
        ...

    @abstractmethod
    def rename_columns(self, symbol_df: pd.DataFrame) -> pd.DataFrame:
        """Return a per-symbol frame keyed by ``tstamp`` with renamed columns."""
        ...

    @abstractmethod
    def tranform_df_target_colnames(self) -> List[str]:
        """Return the wide-format columns (besides ``tstamp``) to keep."""
        ...

    def set_market_data(self) -> None:
        """Rebuild ``market_data_df_`` from ``origin_mkt_data_df_``.

        The raw data is pivoted to one row per timestamp, reduced to
        ``tstamp`` plus the target columns, stripped of rows containing NaN,
        converted to datetime timestamps and sorted by ``tstamp``.
        """
        transformed = self._transform_dataframe(self.origin_mkt_data_df_)
        wanted_columns = ["tstamp"] + self.tranform_df_target_colnames()

        market_data = pd.DataFrame(transformed[wanted_columns])
        market_data = market_data.dropna().reset_index(drop=True)
        market_data["tstamp"] = pd.to_datetime(market_data["tstamp"])
        self.market_data_df_ = market_data.sort_values("tstamp")

    def colnames(self) -> List[str]:
        """Return the stat-model price column names for leg A and leg B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a_}",
            f"{self.stat_model_price_}_{self.symbol_b_}",
        ]

    def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Pivot long-format market data into one row per timestamp.

        Args:
            df: Frame containing at least the columns from ``md_columns()``,
                which must include ``tstamp`` and ``symbol``.

        Returns:
            A frame with a ``tstamp`` column plus the columns produced by
            ``rename_columns`` for every symbol present in ``df``; rows with
            any missing value are dropped.
        """
        df_selected: pd.DataFrame = pd.DataFrame(df[self.md_columns()])

        # Start from the distinct timestamps and attach each symbol's columns.
        result_df = (
            pd.DataFrame(df_selected["tstamp"]).drop_duplicates().reset_index(drop=True)
        )

        for symbol in df_selected["symbol"].unique():
            df_symbol = df_selected[df_selected["symbol"] == symbol].reset_index(
                drop=True
            )
            # Columns are renamed per symbol, e.g. "<price>_<symbol>".
            per_symbol_df: pd.DataFrame = self.rename_columns(df_symbol)
            # Left join keeps every timestamp; NaNs are not dropped inside the
            # loop so that one symbol's gaps do not remove rows prematurely.
            result_df = pd.merge(
                result_df, per_symbol_df, on="tstamp", how="left"
            ).reset_index(drop=True)

        # NOTE(review): this final dropna() removes a timestamp if ANY symbol
        # present in df lacks data for it - confirm df only holds pair symbols.
        return result_df.dropna()
|
|
|
|
class ResearchMarketData(PtMarketData):
    """Market data loaded from data files and replayed row by row.

    When the config contains an ``execution_price`` entry (with ``column`` and
    ``shift`` keys), ``exec_price_<symbol>`` columns are added that hold the
    configured column shifted ``shift`` rows into the future.
    """

    current_index_: int
    is_execution_price_: bool
    execution_price_column_: Optional[str]
    execution_price_shift_: int

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        """Initialise the replay cursor and cache execution-price settings.

        Args:
            config: Configuration; ``execution_price`` is optional.
            instruments: Instruments of the pair (see ``PtMarketData``).
        """
        super().__init__(config, instruments)
        self.current_index_ = 0
        self.is_execution_price_ = self.config_.key_exists("execution_price")
        if self.is_execution_price_:
            # Read the section once instead of once per key.
            execution_price_cfg = self.config_.get_value("execution_price")
            self.execution_price_column_ = execution_price_cfg["column"]
            self.execution_price_shift_ = execution_price_cfg["shift"]
        else:
            self.execution_price_column_ = None
            self.execution_price_shift_ = 0

    def has_next(self) -> bool:
        """Return True while ``get_next`` still has rows to hand out."""
        return self.current_index_ < len(self.market_data_df_)

    def get_next(self) -> pd.Series:
        """Return the row at the cursor and advance the cursor by one.

        Raises:
            IndexError: If the cursor is past the last row (check
                ``has_next`` first).
        """
        result = self.market_data_df_.iloc[self.current_index_]
        self.current_index_ += 1
        return result

    def load(self) -> None:
        """Load all configured data files and build ``market_data_df_``.

        Reads ``datafiles``, ``market_data_loading`` and ``trading_hours``
        from the config, appends every loaded file to
        ``origin_mkt_data_df_``, sorts by ``tstamp``, drops rows with NaN and
        then derives the wide market data and execution prices.

        Raises:
            AssertionError: If the config lists no data files.
        """
        datafiles: List[str] = self.config_.get_value("datafiles", [])
        assert len(datafiles) > 0, "No datafiles found in config"

        # Loop-invariant lookups are resolved once up front.
        instrument_type = self.instruments_[0].user_data_.get(
            "instrument_type", "?instrument_type?"
        )
        db_table_name = self.config_.get_value("market_data_loading")[
            instrument_type
        ]["db_table_name"]
        trading_hours = self.config_.get_value("trading_hours")
        # Passed through to the loader as-is; presumably extends the loaded
        # window so shifted execution prices exist - verify in load_market_data.
        extra_minutes: int = self.execution_price_shift_

        loaded_frames = [
            load_market_data(
                datafile=datafile,
                instruments=self.instruments_,
                db_table_name=db_table_name,
                trading_hours=trading_hours,
                extra_minutes=extra_minutes,
            )
            for datafile in datafiles
        ]
        # One concat instead of re-copying the accumulated frame per file.
        self.origin_mkt_data_df_ = pd.concat(
            [self.origin_mkt_data_df_, *loaded_frames]
        )

        self.origin_mkt_data_df_ = self.origin_mkt_data_df_.sort_values(by="tstamp")
        self.origin_mkt_data_df_ = self.origin_mkt_data_df_.dropna().reset_index(
            drop=True
        )
        self.set_market_data()
        self._set_execution_price_data()

    def _set_execution_price_data(self) -> None:
        """Add ``exec_price_<symbol>`` columns when execution price is configured.

        Each column is the configured execution price column shifted
        ``execution_price_shift_`` rows towards earlier rows, so a row carries
        the price found that many rows later. Rows left without a value by the
        shift are dropped. Does nothing when no execution price is configured.
        """
        if not self.is_execution_price_:
            return

        for symbol in (self.symbol_a_, self.symbol_b_):
            source_column = f"{self.execution_price_column_}_{symbol}"
            self.market_data_df_[f"exec_price_{symbol}"] = self.market_data_df_[
                source_column
            ].shift(-self.execution_price_shift_)
        self.market_data_df_ = self.market_data_df_.dropna().reset_index(drop=True)

    def md_columns(self) -> List[str]:
        """Return the raw columns needed from the original market data.

        The execution price column is appended only when configured and not
        already listed, so selecting these columns never yields duplicates.
        """
        columns = ["tstamp", "symbol", self.stat_model_price_]
        if self.is_execution_price_ and self.execution_price_column_ not in columns:
            columns.append(self.execution_price_column_)
        return columns

    def rename_columns(self, selected_symbol_df: pd.DataFrame) -> pd.DataFrame:
        """Build a per-symbol frame with symbol-suffixed price columns.

        Args:
            selected_symbol_df: Non-empty rows of a single symbol containing
                the columns from ``md_columns()``.

        Returns:
            A frame with ``tstamp``, ``<stat_model_price>_<symbol>`` and, when
            execution price is configured,
            ``<execution_price_column>_<symbol>``.
        """
        symbol = selected_symbol_df.iloc[0]["symbol"]
        columns = {
            "tstamp": selected_symbol_df["tstamp"],
            f"{self.stat_model_price_}_{symbol}": selected_symbol_df[
                self.stat_model_price_
            ],
        }
        if self.is_execution_price_:
            # Same key as above when both settings name the same column.
            columns[f"{self.execution_price_column_}_{symbol}"] = selected_symbol_df[
                self.execution_price_column_
            ]
        return pd.DataFrame(columns)

    def tranform_df_target_colnames(self) -> List[str]:
        """Return stat-model and execution price column names, without repeats."""
        # dict.fromkeys de-duplicates while preserving order.
        return list(dict.fromkeys(self.colnames() + self.orig_exec_prices_colnames()))

    def orig_exec_prices_colnames(self) -> List[str]:
        """Return the unshifted execution price column names, if configured."""
        if not self.is_execution_price_:
            return []
        return [
            f"{self.execution_price_column_}_{self.symbol_a_}",
            f"{self.execution_price_column_}_{self.symbol_b_}",
        ]
|
|
|
|
class LiveMarketData(PtMarketData):
    """Market data variant that tracks only the stat-model price per symbol."""

    def __init__(self, config: Config, instruments: List[ExchangeInstrument]):
        """Delegate construction to ``PtMarketData``."""
        super().__init__(config, instruments)

    def md_columns(self) -> List[str]:
        """Return the raw columns needed from the original market data."""
        price_column = self.stat_model_price_
        return ["tstamp", "symbol", price_column]

    def rename_columns(self, selected_symbol_df: pd.DataFrame) -> pd.DataFrame:
        """Build a per-symbol frame with a symbol-suffixed price column.

        Args:
            selected_symbol_df: Non-empty rows of a single symbol containing
                ``tstamp``, ``symbol`` and the stat-model price column.

        Returns:
            A frame with ``tstamp`` and ``<stat_model_price>_<symbol>``.
        """
        first_row = selected_symbol_df.iloc[0]
        renamed_price = f"{self.stat_model_price_}_{first_row['symbol']}"

        columns = {"tstamp": selected_symbol_df["tstamp"]}
        columns[renamed_price] = selected_symbol_df[self.stat_model_price_]
        return pd.DataFrame(columns)

    def tranform_df_target_colnames(self) -> List[str]:
        """Return the stat-model price column names for both legs."""
        target_columns = self.colnames()
        return target_columns
|