update strategy for quantile models
This commit is contained in:
parent
134860d09e
commit
f9b17473cb
@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import logging
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
EXIT_POSITION = 0
|
EXIT_POSITION = 0
|
||||||
@ -36,19 +37,29 @@ class BuyAndHoldStrategy(StrategyBase):
|
|||||||
|
|
||||||
class ModelPredictionsStrategyBase(StrategyBase):
|
class ModelPredictionsStrategyBase(StrategyBase):
|
||||||
"""Base class for strategies based on model predictions."""
|
"""Base class for strategies based on model predictions."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
predictions,
|
predictions,
|
||||||
name: str = None):
|
name: str = None,
|
||||||
|
future: int = 1,
|
||||||
|
exchange_fee: int = 0.001,
|
||||||
|
target: str = 'close_price'):
|
||||||
self.predictions = predictions
|
self.predictions = predictions
|
||||||
assert 'time_index' in self.predictions.columns
|
assert 'time_index' in self.predictions.columns
|
||||||
assert 'group_id' in self.predictions.columns
|
assert 'group_id' in self.predictions.columns
|
||||||
assert 'prediction' in self.predictions.columns
|
assert 'prediction' in self.predictions.columns
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.future = future
|
||||||
|
self.target = target
|
||||||
|
self.exchange_fee = exchange_fee
|
||||||
|
|
||||||
def info(self):
|
def info(self):
|
||||||
return {'strategy_name': self.name or 'Unknown model'}
|
return {
|
||||||
|
'strategy_name': self.name or 'Unknown model',
|
||||||
|
'future': self.future,
|
||||||
|
'target': self.target
|
||||||
|
}
|
||||||
|
|
||||||
def run(self, data):
|
def run(self, data):
|
||||||
# Adds predictions to data, if prediction is unknown for a given
|
# Adds predictions to data, if prediction is unknown for a given
|
||||||
@ -56,71 +67,90 @@ class ModelPredictionsStrategyBase(StrategyBase):
|
|||||||
merged_data = pd.merge(
|
merged_data = pd.merge(
|
||||||
data, self.predictions, on=['time_index', 'group_id'],
|
data, self.predictions, on=['time_index', 'group_id'],
|
||||||
how='left')
|
how='left')
|
||||||
|
|
||||||
return self.get_positions(merged_data)
|
return self.get_positions(merged_data)
|
||||||
|
|
||||||
def get_positions(self, data):
|
def get_positions(self, data):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class ReturnsPredictionStrategy(ModelPredictionsStrategyBase):
|
class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase):
|
||||||
"""Strategy that selects position based on returns predictions."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
predictions,
|
|
||||||
threshold=0.001,
|
|
||||||
name=None):
|
|
||||||
super().__init__(predictions, name=name)
|
|
||||||
self.threshold = threshold
|
|
||||||
|
|
||||||
def get_positions(self, data):
|
|
||||||
arr = data['prediction']
|
|
||||||
positions = []
|
|
||||||
for i in range(len(arr)):
|
|
||||||
if arr[i] > self.threshold:
|
|
||||||
positions.append(LONG_POSITION)
|
|
||||||
elif arr[i] < -self.threshold:
|
|
||||||
positions.append(EXIT_POSITION)
|
|
||||||
elif not len(positions):
|
|
||||||
positions.append(EXIT_POSITION)
|
|
||||||
else:
|
|
||||||
positions.append(positions[-1])
|
|
||||||
|
|
||||||
return np.array(positions, dtype=np.int32)
|
|
||||||
|
|
||||||
|
|
||||||
class PriceQuantilePredictionStrategy(ModelPredictionsStrategyBase):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
predictions,
|
predictions,
|
||||||
name=None):
|
quantiles,
|
||||||
super().__init__(predictions, name=name)
|
quantile_enter_long=None,
|
||||||
|
quantile_exit_long=None,
|
||||||
|
quantile_enter_short=None,
|
||||||
|
quantile_exit_short=None,
|
||||||
|
name: str = None,
|
||||||
|
future: int = 1,
|
||||||
|
target: str = 'close_price',
|
||||||
|
exchange_fee: int = 0.001
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
predictions,
|
||||||
|
name=name,
|
||||||
|
future=future,
|
||||||
|
target=target,
|
||||||
|
exchange_fee=exchange_fee)
|
||||||
|
|
||||||
def info(self):
|
self.quantiles = quantiles
|
||||||
return {'strategy_name': self.name}
|
self.quantile_enter_long = quantile_enter_long
|
||||||
|
self.quantile_exit_long = quantile_exit_long
|
||||||
|
self.quantile_enter_short = quantile_enter_short
|
||||||
|
self.quantile_exit_short = quantile_exit_short
|
||||||
|
|
||||||
def get_positions(self, data):
|
def get_positions(self, data):
|
||||||
|
|
||||||
arr_preds = data['prediction'].to_numpy()
|
arr_preds = data['prediction'].to_numpy()
|
||||||
arr_close_price = data['close_price'].to_numpy()
|
arr_target = data[self.target].to_numpy()
|
||||||
|
|
||||||
positions = []
|
positions = [EXIT_POSITION]
|
||||||
for i in range(len(arr_preds)):
|
for i in range(len(arr_preds)):
|
||||||
if not np.isnan(arr_preds[i]).any():
|
|
||||||
price = arr_close_price[i]
|
# If strategy does not have prediction
|
||||||
pred_low = arr_preds[i][0]
|
# keep the current position.
|
||||||
pred_high = arr_preds[i][-1]
|
if np.isnan(arr_preds[i]).any():
|
||||||
if (pred_low - price) / price > 0.001:
|
# logging.warning(f"Missing value for time index {i}.")
|
||||||
positions.append(LONG_POSITION)
|
positions.append(positions[-1])
|
||||||
continue
|
continue
|
||||||
elif (pred_high - price) / price < -0.001:
|
|
||||||
positions.append(EXIT_POSITION)
|
target = arr_target[i]
|
||||||
continue
|
prediction = arr_preds[i][self.future - 1]
|
||||||
|
|
||||||
if not len(positions):
|
# Enter long position
|
||||||
|
if (self.quantile_enter_long and
|
||||||
|
(prediction[self.get_quantile_idx(
|
||||||
|
round(1 - self.quantile_enter_long, 2)
|
||||||
|
)] - target)
|
||||||
|
/ target > self.exchange_fee):
|
||||||
|
positions.append(LONG_POSITION)
|
||||||
|
|
||||||
|
# Enter short position
|
||||||
|
elif (self.quantile_enter_short and
|
||||||
|
(prediction[self.get_quantile_idx(
|
||||||
|
self.quantile_enter_short)] - target)
|
||||||
|
/ target < -self.exchange_fee):
|
||||||
|
positions.append(SHORT_POSITION)
|
||||||
|
|
||||||
|
# Exit long position
|
||||||
|
elif (self.quantile_exit_long and
|
||||||
|
(prediction[self.get_quantile_idx(
|
||||||
|
self.quantile_exit_long)] - target)
|
||||||
|
/ target < -self.exchange_fee):
|
||||||
positions.append(EXIT_POSITION)
|
positions.append(EXIT_POSITION)
|
||||||
|
|
||||||
|
# Exit short postion
|
||||||
|
elif (self.quantile_exit_short and
|
||||||
|
(prediction[self.get_quantile_idx(
|
||||||
|
round(1 - self.quantile_exit_short, 2)
|
||||||
|
)] - target) / target > self.exchange_fee):
|
||||||
|
positions.append(EXIT_POSITION)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
positions.append(positions[-1])
|
positions.append(positions[-1])
|
||||||
|
|
||||||
return np.array(positions, dtype=np.int32)
|
return np.array(positions[1:], dtype=np.int32)
|
||||||
|
|
||||||
|
def get_quantile_idx(self, quantile):
|
||||||
|
return self.quantiles.index(quantile)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user