49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any, Dict, Optional, cast, Generator, List
|
|
|
|
import pandas as pd
|
|
|
|
from pt_strategy.trading_pair import TradingPair
|
|
|
|
@dataclass
class Prediction:
    """One prediction record for a trading pair at a given timestamp.

    Attribute names carry a trailing underscore; ``to_dict`` exposes them
    under plain (underscore-free) keys.
    """

    tstamp_: pd.Timestamp  # timestamp the record refers to
    disequilibrium_: float  # unscaled disequilibrium value
    scaled_disequilibrium_: float  # scaled disequilibrium, sign preserved
    pair_: TradingPair  # pair the record belongs to

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dictionary.

        Keys, in order: ``tstamp``, ``disequilibrium``,
        ``signed_scaled_disequilibrium`` (the stored scaled value, sign kept),
        ``scaled_disequilibrium`` (absolute value of the scaled value) and
        ``pair``.
        """
        signed = self.scaled_disequilibrium_
        record: Dict[str, Any] = {}
        record["tstamp"] = self.tstamp_
        record["disequilibrium"] = self.disequilibrium_
        record["signed_scaled_disequilibrium"] = signed
        # The unsigned magnitude is published under the shorter key.
        record["scaled_disequilibrium"] = abs(signed)
        record["pair"] = self.pair_
        return record

    def to_pd_series(self) -> pd.Series:
        """Return ``to_dict()`` as a ``pd.Series``.

        The series is row 0 of a single-row DataFrame built from the dict,
        so its ``name`` is ``0``.
        """
        frame = pd.DataFrame([self.to_dict()])
        return frame.iloc[0]
|
|
|
|
class PairsTradingModel(ABC):
    """Abstract interface for models that produce a Prediction for a pair."""

    @abstractmethod
    def predict(self, pair: TradingPair) -> Prediction:
        """Return a :class:`Prediction` for ``pair``.

        Concrete model classes must override this method.
        """
        ...

    @staticmethod
    def create(config: Dict[str, Any]) -> PairsTradingModel:
        """Instantiate the model class named by ``config["model_class"]``.

        The value must be a dotted path of the form
        ``"package.module.ClassName"``. The module part is imported with
        ``importlib.import_module`` and the class is called with no arguments.

        Args:
            config: Mapping that must contain a ``"model_class"`` entry.

        Returns:
            A new instance of the named class. The result is only ``cast`` to
            ``PairsTradingModel``; its actual type is not checked here.

        Raises:
            ValueError: If ``"model_class"`` is missing, not a string, or not
                of the form ``"module.ClassName"``.
            ImportError: If the module part cannot be imported.
            AttributeError: If the module has no attribute ``ClassName``.
        """
        import importlib

        model_class_name = config.get("model_class", None)
        # Validate explicitly rather than with ``assert``: asserts are removed
        # under ``python -O``, which would let ``None`` through and fail later
        # with an unrelated AttributeError.
        if not isinstance(model_class_name, str):
            raise ValueError(
                "config['model_class'] must be a dotted path string like "
                f"'package.module.ClassName', got {model_class_name!r}"
            )

        # rpartition never raises; an empty side means there was no usable
        # "module.Class" split (e.g. "Foo", "", ".Foo", "pkg.").
        module_name, _, class_name = model_class_name.rpartition(".")
        if not module_name or not class_name:
            raise ValueError(
                "config['model_class'] must be of the form "
                f"'module.ClassName', got {model_class_name!r}"
            )

        module = importlib.import_module(module_name)
        model_object = getattr(module, class_name)()
        return cast(PairsTradingModel, model_object)
|
|
|
|
|
|
|