diff --git a/src/strategy/strategy.py b/src/strategy/strategy.py index 3aa454b..cd367b9 100644 --- a/src/strategy/strategy.py +++ b/src/strategy/strategy.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import logging from typing import Dict, Any EXIT_POSITION = 0 @@ -36,19 +37,29 @@ class BuyAndHoldStrategy(StrategyBase): class ModelPredictionsStrategyBase(StrategyBase): """Base class for strategies based on model predictions.""" + def __init__(self, predictions, - name: str = None): + name: str = None, + future: int = 1, + exchange_fee: int = 0.001, + target: str = 'close_price'): self.predictions = predictions assert 'time_index' in self.predictions.columns assert 'group_id' in self.predictions.columns assert 'prediction' in self.predictions.columns self.name = name - + self.future = future + self.target = target + self.exchange_fee = exchange_fee def info(self): - return {'strategy_name': self.name or 'Unknown model'} + return { + 'strategy_name': self.name or 'Unknown model', + 'future': self.future, + 'target': self.target + } def run(self, data): # Adds predictions to data, if prediction is unknown for a given @@ -56,71 +67,90 @@ class ModelPredictionsStrategyBase(StrategyBase): merged_data = pd.merge( data, self.predictions, on=['time_index', 'group_id'], how='left') - + return self.get_positions(merged_data) def get_positions(self, data): raise NotImplementedError() -class ReturnsPredictionStrategy(ModelPredictionsStrategyBase): - """Strategy that selects position based on returns predictions.""" - - def __init__( - self, - predictions, - threshold=0.001, - name=None): - super().__init__(predictions, name=name) - self.threshold = threshold - - def get_positions(self, data): - arr = data['prediction'] - positions = [] - for i in range(len(arr)): - if arr[i] > self.threshold: - positions.append(LONG_POSITION) - elif arr[i] < -self.threshold: - positions.append(EXIT_POSITION) - elif not len(positions): - positions.append(EXIT_POSITION) - else: - positions.append(positions[-1]) - - return np.array(positions, dtype=np.int32) - - -class PriceQuantilePredictionStrategy(ModelPredictionsStrategyBase): +class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase): def __init__( self, predictions, - name=None): - super().__init__(predictions, name=name) + quantiles, + quantile_enter_long=None, + quantile_exit_long=None, + quantile_enter_short=None, + quantile_exit_short=None, + name: str = None, + future: int = 1, + target: str = 'close_price', + exchange_fee: int = 0.001 + ): + super().__init__( + predictions, + name=name, + future=future, + target=target, + exchange_fee=exchange_fee) - def info(self): - return {'strategy_name': self.name} + self.quantiles = quantiles + self.quantile_enter_long = quantile_enter_long + self.quantile_exit_long = quantile_exit_long + self.quantile_enter_short = quantile_enter_short + self.quantile_exit_short = quantile_exit_short def get_positions(self, data): - arr_preds = data['prediction'].to_numpy() - arr_close_price = data['close_price'].to_numpy() + arr_target = data[self.target].to_numpy() - positions = [] + positions = [EXIT_POSITION] for i in range(len(arr_preds)): - if not np.isnan(arr_preds[i]).any(): - price = arr_close_price[i] - pred_low = arr_preds[i][0] - pred_high = arr_preds[i][-1] - if (pred_low - price) / price > 0.001: - positions.append(LONG_POSITION) - continue - elif (pred_high - price) / price < -0.001: - positions.append(EXIT_POSITION) - continue - - if not len(positions): + + # If strategy does not have prediction + # keep the current position. + if np.isnan(arr_preds[i]).any(): + # logging.warning(f"Missing value for time index {i}.") + positions.append(positions[-1]) + continue + + target = arr_target[i] + prediction = arr_preds[i][self.future - 1] + + # Enter long position + if (self.quantile_enter_long and + (prediction[self.get_quantile_idx( + round(1 - self.quantile_enter_long, 2) + )] - target) + / target > self.exchange_fee): + positions.append(LONG_POSITION) + + # Enter short position + elif (self.quantile_enter_short and + (prediction[self.get_quantile_idx( + self.quantile_enter_short)] - target) + / target < -self.exchange_fee): + positions.append(SHORT_POSITION) + + # Exit long position + elif (self.quantile_exit_long and + (prediction[self.get_quantile_idx( + self.quantile_exit_long)] - target) + / target < -self.exchange_fee): positions.append(EXIT_POSITION) + + # Exit short postion + elif (self.quantile_exit_short and + (prediction[self.get_quantile_idx( + round(1 - self.quantile_exit_short, 2) + )] - target) / target > self.exchange_fee): + positions.append(EXIT_POSITION) + else: positions.append(positions[-1]) - return np.array(positions, dtype=np.int32) \ No newline at end of file + return np.array(positions[1:], dtype=np.int32) + + def get_quantile_idx(self, quantile): + return self.quantiles.index(quantile)