Add code for evaluating strategies
This commit is contained in:
parent
b3aacb65a3
commit
a0e06cfc06
@ -1,13 +1,58 @@
|
|||||||
|
from typing import Dict, List, Any, Optional, Callable
|
||||||
|
import itertools
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import functools
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import Pool
|
||||||
from strategy import metrics
|
from strategy import metrics
|
||||||
from strategy.strategy import LONG_POSITION, SHORT_POSITION, EXIT_POSITION
|
from strategy.strategy import LONG_POSITION, SHORT_POSITION, EXIT_POSITION
|
||||||
from strategy.strategy import StrategyBase
|
from strategy.strategy import StrategyBase
|
||||||
|
|
||||||
|
|
||||||
|
def parameter_sweep(
|
||||||
|
data: pd.DataFrame,
|
||||||
|
strategy_class: StrategyBase.__class__,
|
||||||
|
params: Dict[str, List[Any]],
|
||||||
|
num_workers: int = 4,
|
||||||
|
params_filter: Optional[Callable] = None,
|
||||||
|
log_every: int = 200,
|
||||||
|
exchange_fee: float = 0.001,
|
||||||
|
interval: str = '5min') -> pd.DataFrame:
|
||||||
|
"""Evaluates the strategy on a different sets of hyperparameters."""
|
||||||
|
|
||||||
|
# Obtain sets of parameters to evaluate
|
||||||
|
param_sets = list(filter(params_filter, map(lambda p: dict(
|
||||||
|
zip(params.keys(), p)), itertools.product(*params.values()))))
|
||||||
|
|
||||||
|
result = []
|
||||||
|
total = len(param_sets)
|
||||||
|
|
||||||
|
# Evaluate sets of different hyperparameters in parallel
|
||||||
|
with Pool(num_workers) as pool, tqdm(total=total) as pbar:
|
||||||
|
for chunk in (param_sets[i:i + log_every]
|
||||||
|
for i in range(0, total, log_every)):
|
||||||
|
tmp = list(
|
||||||
|
pool.map(
|
||||||
|
functools.partial(
|
||||||
|
evaluate_strategy,
|
||||||
|
data,
|
||||||
|
exchange_fee=exchange_fee,
|
||||||
|
interval=interval,
|
||||||
|
include_arrays=False),
|
||||||
|
map(
|
||||||
|
lambda p: strategy_class(
|
||||||
|
**p), chunk)))
|
||||||
|
pbar.update(len(tmp))
|
||||||
|
result += tmp
|
||||||
|
|
||||||
|
return pd.DataFrame(result)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_strategy(
|
def evaluate_strategy(
|
||||||
data: pd.DataFrame,
|
data: pd.DataFrame,
|
||||||
strategy: StrategyBase,
|
strategy: StrategyBase,
|
||||||
|
include_arrays: bool = True,
|
||||||
exchange_fee: float = 0.001,
|
exchange_fee: float = 0.001,
|
||||||
interval: str = "5min"):
|
interval: str = "5min"):
|
||||||
"""Evaluates a trading strategy."""
|
"""Evaluates a trading strategy."""
|
||||||
@ -57,7 +102,12 @@ def evaluate_strategy(
|
|||||||
np.append(positions[1:], [EXIT_POSITION])),
|
np.append(positions[1:], [EXIT_POSITION])),
|
||||||
'long_pos': np.sum(positions == LONG_POSITION) / positions.size,
|
'long_pos': np.sum(positions == LONG_POSITION) / positions.size,
|
||||||
'short_pos': np.sum(positions == SHORT_POSITION) / positions.size,
|
'short_pos': np.sum(positions == SHORT_POSITION) / positions.size,
|
||||||
# Arrays
|
}
|
||||||
|
|
||||||
|
result |= strategy.info()
|
||||||
|
|
||||||
|
if include_arrays:
|
||||||
|
result |= {
|
||||||
'portfolio_value': portfolio_value,
|
'portfolio_value': portfolio_value,
|
||||||
'strategy_returns': strategy_returns,
|
'strategy_returns': strategy_returns,
|
||||||
'strategy_positions': np.append([EXIT_POSITION], positions),
|
'strategy_positions': np.append([EXIT_POSITION], positions),
|
||||||
|
|||||||
54
src/strategy/plotting.py
Normal file
54
src/strategy/plotting.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from typing import List
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.figure_factory as ff
|
||||||
|
import plotly.express as px
|
||||||
|
|
||||||
|
|
||||||
|
def plot_sweep_results(
|
||||||
|
sweep_results: pd.DataFrame,
|
||||||
|
parameters: List[str],
|
||||||
|
objective: str = 'value',
|
||||||
|
top_n: int = 5,
|
||||||
|
title: str = "Hyperparameters search results"):
|
||||||
|
"""Helper function for plotting results of hyperparameter search."""
|
||||||
|
data = sweep_results[list(parameters) + [objective]].round(2)
|
||||||
|
|
||||||
|
fig = ff.create_table(
|
||||||
|
data.sort_values(
|
||||||
|
objective,
|
||||||
|
ascending=False).head(top_n),
|
||||||
|
height_constant=80)
|
||||||
|
fig.layout.yaxis.update({'domain': [0, .4]})
|
||||||
|
|
||||||
|
parcoords = px.parallel_coordinates(
|
||||||
|
data,
|
||||||
|
color=objective,
|
||||||
|
color_continuous_midpoint=1.0,
|
||||||
|
color_continuous_scale=px.colors.diverging.Tealrose_r)
|
||||||
|
parcoords.data[0].domain.update({'x': [0.05, 0.8], 'y': [0.5, 0.90]})
|
||||||
|
|
||||||
|
fig.add_trace(parcoords.data[0])
|
||||||
|
fig.layout.update({'coloraxis': parcoords.layout.coloraxis})
|
||||||
|
fig.update_layout(coloraxis_colorbar=dict(
|
||||||
|
yanchor="top",
|
||||||
|
xanchor='right',
|
||||||
|
y=1,
|
||||||
|
x=0.95,
|
||||||
|
len=0.5,
|
||||||
|
thickness=40,
|
||||||
|
))
|
||||||
|
|
||||||
|
fig.layout.margin.update({'l': 20, 'r': 20, 'b': 20, 't': 40})
|
||||||
|
fig.update_layout(
|
||||||
|
title={
|
||||||
|
'text': title,
|
||||||
|
'y': 0.98,
|
||||||
|
'x': 0.5,
|
||||||
|
'xanchor': 'center',
|
||||||
|
'yanchor': 'top',
|
||||||
|
'font': {
|
||||||
|
'size': 28
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import logging
|
# import logging
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
EXIT_POSITION = 0
|
EXIT_POSITION = 0
|
||||||
@ -101,6 +101,15 @@ class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase):
|
|||||||
self.quantile_enter_short = quantile_enter_short
|
self.quantile_enter_short = quantile_enter_short
|
||||||
self.quantile_exit_short = quantile_exit_short
|
self.quantile_exit_short = quantile_exit_short
|
||||||
|
|
||||||
|
def info(self):
|
||||||
|
return super().info() | {
|
||||||
|
'quantiles': self.quantiles,
|
||||||
|
'quantile_enter_long': self.quantile_enter_long,
|
||||||
|
'quantile_exit_long': self.quantile_exit_long,
|
||||||
|
'quantile_enter_short': self.quantile_enter_short,
|
||||||
|
'quantile_exit_short': self.quantile_exit_short
|
||||||
|
}
|
||||||
|
|
||||||
def get_positions(self, data):
|
def get_positions(self, data):
|
||||||
arr_preds = data['prediction'].to_numpy()
|
arr_preds = data['prediction'].to_numpy()
|
||||||
arr_target = data[self.target].to_numpy()
|
arr_target = data[self.target].to_numpy()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user