update strategy for quantile models

This commit is contained in:
Filip Stefaniuk 2024-09-14 20:30:19 +02:00
parent 134860d09e
commit f9b17473cb

View File

@ -1,5 +1,6 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import logging
from typing import Dict, Any from typing import Dict, Any
EXIT_POSITION = 0 EXIT_POSITION = 0
@ -36,19 +37,29 @@ class BuyAndHoldStrategy(StrategyBase):
class ModelPredictionsStrategyBase(StrategyBase): class ModelPredictionsStrategyBase(StrategyBase):
"""Base class for strategies based on model predictions.""" """Base class for strategies based on model predictions."""
def __init__(self, def __init__(self,
predictions, predictions,
name: str = None): name: str = None,
future: int = 1,
exchange_fee: int = 0.001,
target: str = 'close_price'):
self.predictions = predictions self.predictions = predictions
assert 'time_index' in self.predictions.columns assert 'time_index' in self.predictions.columns
assert 'group_id' in self.predictions.columns assert 'group_id' in self.predictions.columns
assert 'prediction' in self.predictions.columns assert 'prediction' in self.predictions.columns
self.name = name self.name = name
self.future = future
self.target = target
self.exchange_fee = exchange_fee
def info(self): def info(self):
return {'strategy_name': self.name or 'Unknown model'} return {
'strategy_name': self.name or 'Unknown model',
'future': self.future,
'target': self.target
}
def run(self, data): def run(self, data):
# Adds predictions to data, if prediction is unknown for a given # Adds predictions to data, if prediction is unknown for a given
@ -63,64 +74,83 @@ class ModelPredictionsStrategyBase(StrategyBase):
raise NotImplementedError() raise NotImplementedError()
class ReturnsPredictionStrategy(ModelPredictionsStrategyBase): class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase):
"""Strategy that selects position based on returns predictions."""
def __init__( def __init__(
self, self,
predictions, predictions,
threshold=0.001, quantiles,
name=None): quantile_enter_long=None,
super().__init__(predictions, name=name) quantile_exit_long=None,
self.threshold = threshold quantile_enter_short=None,
quantile_exit_short=None,
def get_positions(self, data): name: str = None,
arr = data['prediction'] future: int = 1,
positions = [] target: str = 'close_price',
for i in range(len(arr)): exchange_fee: int = 0.001
if arr[i] > self.threshold: ):
positions.append(LONG_POSITION) super().__init__(
elif arr[i] < -self.threshold:
positions.append(EXIT_POSITION)
elif not len(positions):
positions.append(EXIT_POSITION)
else:
positions.append(positions[-1])
return np.array(positions, dtype=np.int32)
class PriceQuantilePredictionStrategy(ModelPredictionsStrategyBase):
def __init__(
self,
predictions, predictions,
name=None): name=name,
super().__init__(predictions, name=name) future=future,
target=target,
exchange_fee=exchange_fee)
def info(self): self.quantiles = quantiles
return {'strategy_name': self.name} self.quantile_enter_long = quantile_enter_long
self.quantile_exit_long = quantile_exit_long
self.quantile_enter_short = quantile_enter_short
self.quantile_exit_short = quantile_exit_short
def get_positions(self, data): def get_positions(self, data):
arr_preds = data['prediction'].to_numpy() arr_preds = data['prediction'].to_numpy()
arr_close_price = data['close_price'].to_numpy() arr_target = data[self.target].to_numpy()
positions = [] positions = [EXIT_POSITION]
for i in range(len(arr_preds)): for i in range(len(arr_preds)):
if not np.isnan(arr_preds[i]).any():
price = arr_close_price[i] # If strategy does not have prediction
pred_low = arr_preds[i][0] # keep the current position.
pred_high = arr_preds[i][-1] if np.isnan(arr_preds[i]).any():
if (pred_low - price) / price > 0.001: # logging.warning(f"Missing value for time index {i}.")
positions.append(LONG_POSITION) positions.append(positions[-1])
continue
elif (pred_high - price) / price < -0.001:
positions.append(EXIT_POSITION)
continue continue
if not len(positions): target = arr_target[i]
prediction = arr_preds[i][self.future - 1]
# Enter long position
if (self.quantile_enter_long and
(prediction[self.get_quantile_idx(
round(1 - self.quantile_enter_long, 2)
)] - target)
/ target > self.exchange_fee):
positions.append(LONG_POSITION)
# Enter short position
elif (self.quantile_enter_short and
(prediction[self.get_quantile_idx(
self.quantile_enter_short)] - target)
/ target < -self.exchange_fee):
positions.append(SHORT_POSITION)
# Exit long position
elif (self.quantile_exit_long and
(prediction[self.get_quantile_idx(
self.quantile_exit_long)] - target)
/ target < -self.exchange_fee):
positions.append(EXIT_POSITION) positions.append(EXIT_POSITION)
# Exit short postion
elif (self.quantile_exit_short and
(prediction[self.get_quantile_idx(
round(1 - self.quantile_exit_short, 2)
)] - target) / target > self.exchange_fee):
positions.append(EXIT_POSITION)
else: else:
positions.append(positions[-1]) positions.append(positions[-1])
return np.array(positions, dtype=np.int32) return np.array(positions[1:], dtype=np.int32)
def get_quantile_idx(self, quantile):
return self.quantiles.index(quantile)