188 lines
7.1 KiB
Python
188 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Type
|
|
|
|
import pandas as pd
|
|
|
|
from tools.data_loader import load_market_data
|
|
|
|
|
|
class PtMarketData(ABC):
    """Abstract source of market data.

    Interface: ``load()`` prepares the data, then ``has_next()`` /
    ``get_next()`` hand it out one ``pd.Series`` at a time. Instances are
    normally built through :meth:`create`.
    """

    config_: Dict[str, Any]  # configuration dict passed to the constructor
    origin_mkt_data_df_: pd.DataFrame  # raw data as loaded (starts empty)
    market_data_df_: pd.DataFrame  # data served by get_next() (starts empty)

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and start with empty raw/processed frames."""
        self.config_ = config
        self.origin_mkt_data_df_ = pd.DataFrame()
        # Initialised here so has_next()/len() see an empty frame before load()
        # instead of raising AttributeError on an unset attribute.
        self.market_data_df_ = pd.DataFrame()

    @abstractmethod
    def load(self) -> None:
        """Load and prepare the market data described by ``config_``."""
        ...

    @abstractmethod
    def has_next(self) -> bool:
        """Return True if :meth:`get_next` can return another row."""
        ...

    @abstractmethod
    def get_next(self) -> pd.Series:
        """Return the next row of market data."""
        ...

    @staticmethod
    def create(config: Dict[str, Any], md_class: Type[PtMarketData]) -> PtMarketData:
        """Instantiate ``md_class`` (a concrete subclass) with ``config``."""
        return md_class(config)
|
|
|
|
class ResearchMarketData(PtMarketData):
    """Pair (two-instrument) market data loaded from datafiles and served row by row.

    ``load()`` reads every entry of ``config_["datafiles"]`` through
    ``load_market_data``. It then reshapes the long data (one row per
    ``tstamp`` and ``symbol``) into a wide frame with one row per ``tstamp``
    and columns named ``<price column>_<symbol>``. Finally it adds
    ``exec_price_<symbol>`` columns for the first two instruments.

    Config keys read by this class:
      * ``instruments``: list of dicts. ``symbol`` of the first two entries and
        ``instrument_type`` of the first entry are used.
      * ``datafiles``: list of datafiles passed to ``load_market_data``.
      * ``stat_model_price``: name of the price column used for the model.
      * ``market_data_loading[<instrument_type>]["db_table_name"]``.
      * ``trading_hours``: forwarded to ``load_market_data``.
      * ``execution_price`` (optional): ``{"column": ..., "shift": ...}``.
        When present, ``exec_price_<symbol>`` is that column shifted ``shift``
        rows forward. Otherwise it equals the stat-model price column.
    """

    config_: Dict[str, Any]
    current_index_: int  # position of the row the next get_next() returns
    is_execution_price_: bool  # True when config_ has an "execution_price" section

    def __init__(self, config: Dict[str, Any]):
        """Read the optional execution-price settings from ``config``."""
        super().__init__(config)
        self.current_index_ = 0
        self.is_execution_price_ = "execution_price" in self.config_
        if self.is_execution_price_:
            self.execution_price_column_ = self.config_["execution_price"]["column"]
            self.execution_price_shift_ = self.config_["execution_price"]["shift"]
        else:
            self.execution_price_column_ = None
            self.execution_price_shift_ = 0

    def has_next(self) -> bool:
        """Return True while :meth:`get_next` still has rows to return."""
        return self.current_index_ < len(self.market_data_df_)

    def get_next(self) -> pd.Series:
        """Return the current row and advance. Raises IndexError past the end."""
        result = self.market_data_df_.iloc[self.current_index_]
        self.current_index_ += 1
        return result

    def load(self) -> None:
        """Load all configured datafiles and build ``market_data_df_``.

        Raises:
            AssertionError: if ``instruments`` has fewer than two entries or
                ``datafiles`` is empty.
        """
        datafiles: List[str] = self.config_.get("datafiles", [])
        instruments: List[Dict[str, str]] = self.config_.get("instruments", [])
        assert len(instruments) > 0, "No instruments found in config"
        assert len(datafiles) > 0, "No datafiles found in config"
        # The second instrument's symbol is read below; fail clearly instead
        # of with an IndexError.
        assert len(instruments) >= 2, "At least two instruments are required in config"

        self.symbol_a_ = instruments[0]["symbol"]
        self.symbol_b_ = instruments[1]["symbol"]
        self.stat_model_price_ = self.config_["stat_model_price"]

        # Loop-invariant lookups hoisted out of the per-file loop.
        db_table_name = self.config_["market_data_loading"][instruments[0]["instrument_type"]]["db_table_name"]
        trading_hours = self.config_["trading_hours"]

        # Collect every file's frame and concatenate once. This avoids
        # quadratic re-copying and concatenating onto an empty frame. It also
        # means a repeated load() rebuilds the data instead of duplicating it.
        frames = [
            load_market_data(
                datafile=datafile,
                instruments=instruments,
                db_table_name=db_table_name,
                trading_hours=trading_hours,
                # NOTE(review): presumably widens the loaded window so the
                # forward shift of execution prices has rows to read - confirm.
                extra_minutes=self.execution_price_shift_,
            )
            for datafile in datafiles
        ]
        self.origin_mkt_data_df_ = (
            pd.concat(frames)
            .sort_values(by="tstamp")
            .dropna()
            .reset_index(drop=True)
        )
        self._set_market_data()

    def _set_market_data(self) -> None:
        """Build ``market_data_df_`` (wide, NaN-free, sorted by tstamp) from the raw data."""
        columns = ["tstamp"] + self.colnames()
        if self.is_execution_price_:
            columns += self.orig_exec_prices_colnames()
        # De-duplicate while keeping order. When the execution price column is
        # the stat-model price column, selecting it twice would create
        # duplicate labels.
        columns = list(dict.fromkeys(columns))

        wide_df = self._transform_dataframe(self.origin_mkt_data_df_)
        # NaNs are dropped only after selecting the columns this pair uses.
        market_df = wide_df[columns].dropna().reset_index(drop=True)
        market_df["tstamp"] = pd.to_datetime(market_df["tstamp"])
        self.market_data_df_ = market_df.sort_values("tstamp").reset_index(drop=True)
        self._set_execution_price_data()

    def _transform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape long rows (tstamp, symbol, prices) into one row per unique tstamp.

        For each symbol present in ``df`` and each selected price column
        ``col``, the result gets a column ``f"{col}_{symbol}"``. These columns
        are left-joined on ``tstamp``, so missing values stay NaN.

        NaNs are deliberately not dropped here. A gap in a symbol the caller
        does not use would otherwise discard rows for the symbols it does use.
        The caller drops NaNs after selecting its own columns.
        """
        price_columns: List[str] = [self.stat_model_price_]
        if self.is_execution_price_ and self.execution_price_column_ != self.stat_model_price_:
            price_columns.append(self.execution_price_column_)

        df_selected = df[["tstamp", "symbol"] + price_columns]
        result_df = df_selected[["tstamp"]].drop_duplicates().reset_index(drop=True)

        for symbol in df_selected["symbol"].unique():
            # Rows of this symbol, price columns renamed e.g. "close" -> "close_COIN".
            symbol_df = df_selected.loc[
                df_selected["symbol"] == symbol, ["tstamp"] + price_columns
            ].rename(columns={col: f"{col}_{symbol}" for col in price_columns})
            result_df = pd.merge(result_df, symbol_df, on="tstamp", how="left")

        return result_df

    def _set_execution_price_data(self) -> None:
        """Add ``exec_price_<symbol>`` columns for both symbols of the pair.

        Without execution-price config, these columns copy the stat-model
        price. With it, they hold the configured column shifted so that each
        row takes the value ``execution_price_shift_`` rows later. Rows left
        NaN (and any other row containing NaN) are then dropped.
        """
        symbols = (self.symbol_a_, self.symbol_b_)
        if not self.is_execution_price_:
            for symbol in symbols:
                self.market_data_df_[f"exec_price_{symbol}"] = self.market_data_df_[f"{self.stat_model_price_}_{symbol}"]
            return

        for symbol in symbols:
            self.market_data_df_[f"exec_price_{symbol}"] = self.market_data_df_[
                f"{self.execution_price_column_}_{symbol}"
            ].shift(-self.execution_price_shift_)
        self.market_data_df_ = self.market_data_df_.dropna().reset_index(drop=True)

    def colnames(self) -> List[str]:
        """Stat-model price column names for symbol A and symbol B."""
        return [
            f"{self.stat_model_price_}_{self.symbol_a_}",
            f"{self.stat_model_price_}_{self.symbol_b_}",
        ]

    def orig_exec_prices_colnames(self) -> List[str]:
        """Unshifted execution-price column names (meaningful only when is_execution_price_)."""
        return [
            f"{self.execution_price_column_}_{self.symbol_a_}",
            f"{self.execution_price_column_}_{self.symbol_b_}",
        ]

    def exec_prices_colnames(self) -> List[str]:
        """Names of the ``exec_price_<symbol>`` columns added to ``market_data_df_``."""
        return [
            f"exec_price_{self.symbol_a_}",
            f"exec_price_{self.symbol_b_}",
        ]
|
|
|