Add code for evaluating strategies
This commit is contained in:
parent
b3aacb65a3
commit
a0e06cfc06
@ -1,13 +1,58 @@
|
||||
from typing import Dict, List, Any, Optional, Callable
|
||||
import itertools
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import functools
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
from strategy import metrics
|
||||
from strategy.strategy import LONG_POSITION, SHORT_POSITION, EXIT_POSITION
|
||||
from strategy.strategy import StrategyBase
|
||||
|
||||
|
||||
def parameter_sweep(
|
||||
data: pd.DataFrame,
|
||||
strategy_class: StrategyBase.__class__,
|
||||
params: Dict[str, List[Any]],
|
||||
num_workers: int = 4,
|
||||
params_filter: Optional[Callable] = None,
|
||||
log_every: int = 200,
|
||||
exchange_fee: float = 0.001,
|
||||
interval: str = '5min') -> pd.DataFrame:
|
||||
"""Evaluates the strategy on a different sets of hyperparameters."""
|
||||
|
||||
# Obtain sets of parameters to evaluate
|
||||
param_sets = list(filter(params_filter, map(lambda p: dict(
|
||||
zip(params.keys(), p)), itertools.product(*params.values()))))
|
||||
|
||||
result = []
|
||||
total = len(param_sets)
|
||||
|
||||
# Evaluate sets of different hyperparameters in parallel
|
||||
with Pool(num_workers) as pool, tqdm(total=total) as pbar:
|
||||
for chunk in (param_sets[i:i + log_every]
|
||||
for i in range(0, total, log_every)):
|
||||
tmp = list(
|
||||
pool.map(
|
||||
functools.partial(
|
||||
evaluate_strategy,
|
||||
data,
|
||||
exchange_fee=exchange_fee,
|
||||
interval=interval,
|
||||
include_arrays=False),
|
||||
map(
|
||||
lambda p: strategy_class(
|
||||
**p), chunk)))
|
||||
pbar.update(len(tmp))
|
||||
result += tmp
|
||||
|
||||
return pd.DataFrame(result)
|
||||
|
||||
|
||||
def evaluate_strategy(
|
||||
data: pd.DataFrame,
|
||||
strategy: StrategyBase,
|
||||
include_arrays: bool = True,
|
||||
exchange_fee: float = 0.001,
|
||||
interval: str = "5min"):
|
||||
"""Evaluates a trading strategy."""
|
||||
@ -57,11 +102,16 @@ def evaluate_strategy(
|
||||
np.append(positions[1:], [EXIT_POSITION])),
|
||||
'long_pos': np.sum(positions == LONG_POSITION) / positions.size,
|
||||
'short_pos': np.sum(positions == SHORT_POSITION) / positions.size,
|
||||
# Arrays
|
||||
'portfolio_value': portfolio_value,
|
||||
'strategy_returns': strategy_returns,
|
||||
'strategy_positions': np.append([EXIT_POSITION], positions),
|
||||
'time': timestamps
|
||||
}
|
||||
|
||||
result |= strategy.info()
|
||||
|
||||
if include_arrays:
|
||||
result |= {
|
||||
'portfolio_value': portfolio_value,
|
||||
'strategy_returns': strategy_returns,
|
||||
'strategy_positions': np.append([EXIT_POSITION], positions),
|
||||
'time': timestamps
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
54
src/strategy/plotting.py
Normal file
54
src/strategy/plotting.py
Normal file
@ -0,0 +1,54 @@
|
||||
from typing import List
|
||||
import pandas as pd
|
||||
import plotly.figure_factory as ff
|
||||
import plotly.express as px
|
||||
|
||||
|
||||
def plot_sweep_results(
|
||||
sweep_results: pd.DataFrame,
|
||||
parameters: List[str],
|
||||
objective: str = 'value',
|
||||
top_n: int = 5,
|
||||
title: str = "Hyperparameters search results"):
|
||||
"""Helper function for plotting results of hyperparameter search."""
|
||||
data = sweep_results[list(parameters) + [objective]].round(2)
|
||||
|
||||
fig = ff.create_table(
|
||||
data.sort_values(
|
||||
objective,
|
||||
ascending=False).head(top_n),
|
||||
height_constant=80)
|
||||
fig.layout.yaxis.update({'domain': [0, .4]})
|
||||
|
||||
parcoords = px.parallel_coordinates(
|
||||
data,
|
||||
color=objective,
|
||||
color_continuous_midpoint=1.0,
|
||||
color_continuous_scale=px.colors.diverging.Tealrose_r)
|
||||
parcoords.data[0].domain.update({'x': [0.05, 0.8], 'y': [0.5, 0.90]})
|
||||
|
||||
fig.add_trace(parcoords.data[0])
|
||||
fig.layout.update({'coloraxis': parcoords.layout.coloraxis})
|
||||
fig.update_layout(coloraxis_colorbar=dict(
|
||||
yanchor="top",
|
||||
xanchor='right',
|
||||
y=1,
|
||||
x=0.95,
|
||||
len=0.5,
|
||||
thickness=40,
|
||||
))
|
||||
|
||||
fig.layout.margin.update({'l': 20, 'r': 20, 'b': 20, 't': 40})
|
||||
fig.update_layout(
|
||||
title={
|
||||
'text': title,
|
||||
'y': 0.98,
|
||||
'x': 0.5,
|
||||
'xanchor': 'center',
|
||||
'yanchor': 'top',
|
||||
'font': {
|
||||
'size': 28
|
||||
}}
|
||||
)
|
||||
|
||||
return fig
|
||||
@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
# import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
EXIT_POSITION = 0
|
||||
@ -101,6 +101,15 @@ class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase):
|
||||
self.quantile_enter_short = quantile_enter_short
|
||||
self.quantile_exit_short = quantile_exit_short
|
||||
|
||||
def info(self):
|
||||
return super().info() | {
|
||||
'quantiles': self.quantiles,
|
||||
'quantile_enter_long': self.quantile_enter_long,
|
||||
'quantile_exit_long': self.quantile_exit_long,
|
||||
'quantile_enter_short': self.quantile_enter_short,
|
||||
'quantile_exit_short': self.quantile_exit_short
|
||||
}
|
||||
|
||||
def get_positions(self, data):
|
||||
arr_preds = data['prediction'].to_numpy()
|
||||
arr_target = data[self.target].to_numpy()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user