From a0e06cfc06f65bf8f2c6f43985c20cff5896d978 Mon Sep 17 00:00:00 2001 From: Filip Stefaniuk Date: Sun, 15 Sep 2024 12:01:39 +0200 Subject: [PATCH] Add code for evaluating strategies --- src/strategy/evaluation.py | 60 ++++++++++++++++++++++++++++++++++---- src/strategy/plotting.py | 54 ++++++++++++++++++++++++++++++++++ src/strategy/strategy.py | 11 ++++++- 3 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 src/strategy/plotting.py diff --git a/src/strategy/evaluation.py b/src/strategy/evaluation.py index 4a273cc..8db2160 100644 --- a/src/strategy/evaluation.py +++ b/src/strategy/evaluation.py @@ -1,13 +1,58 @@ +from typing import Dict, List, Any, Optional, Callable +import itertools import pandas as pd import numpy as np +import functools +from tqdm import tqdm +from multiprocessing import Pool from strategy import metrics from strategy.strategy import LONG_POSITION, SHORT_POSITION, EXIT_POSITION from strategy.strategy import StrategyBase +def parameter_sweep( + data: pd.DataFrame, + strategy_class: StrategyBase.__class__, + params: Dict[str, List[Any]], + num_workers: int = 4, + params_filter: Optional[Callable] = None, + log_every: int = 200, + exchange_fee: float = 0.001, + interval: str = '5min') -> pd.DataFrame: + """Evaluates the strategy on a different sets of hyperparameters.""" + + # Obtain sets of parameters to evaluate + param_sets = list(filter(params_filter, map(lambda p: dict( + zip(params.keys(), p)), itertools.product(*params.values())))) + + result = [] + total = len(param_sets) + + # Evaluate sets of different hyperparameters in parallel + with Pool(num_workers) as pool, tqdm(total=total) as pbar: + for chunk in (param_sets[i:i + log_every] + for i in range(0, total, log_every)): + tmp = list( + pool.map( + functools.partial( + evaluate_strategy, + data, + exchange_fee=exchange_fee, + interval=interval, + include_arrays=False), + map( + lambda p: strategy_class( + **p), chunk))) + pbar.update(len(tmp)) + result += tmp + + return pd.DataFrame(result) + + def evaluate_strategy( data: pd.DataFrame, strategy: StrategyBase, + include_arrays: bool = True, exchange_fee: float = 0.001, interval: str = "5min"): """Evaluates a trading strategy.""" @@ -57,11 +102,16 @@ def evaluate_strategy( np.append(positions[1:], [EXIT_POSITION])), 'long_pos': np.sum(positions == LONG_POSITION) / positions.size, 'short_pos': np.sum(positions == SHORT_POSITION) / positions.size, - # Arrays - 'portfolio_value': portfolio_value, - 'strategy_returns': strategy_returns, - 'strategy_positions': np.append([EXIT_POSITION], positions), - 'time': timestamps } + result |= strategy.info() + + if include_arrays: + result |= { + 'portfolio_value': portfolio_value, + 'strategy_returns': strategy_returns, + 'strategy_positions': np.append([EXIT_POSITION], positions), + 'time': timestamps + } + return result diff --git a/src/strategy/plotting.py b/src/strategy/plotting.py new file mode 100644 index 0000000..ed79b4c --- /dev/null +++ b/src/strategy/plotting.py @@ -0,0 +1,54 @@ +from typing import List +import pandas as pd +import plotly.figure_factory as ff +import plotly.express as px + + +def plot_sweep_results( + sweep_results: pd.DataFrame, + parameters: List[str], + objective: str = 'value', + top_n: int = 5, + title: str = "Hyperparameters search results"): + """Helper function for plotting results of hyperparameter search.""" + data = sweep_results[list(parameters) + [objective]].round(2) + + fig = ff.create_table( + data.sort_values( + objective, + ascending=False).head(top_n), + height_constant=80) + fig.layout.yaxis.update({'domain': [0, .4]}) + + parcoords = px.parallel_coordinates( + data, + color=objective, + color_continuous_midpoint=1.0, + color_continuous_scale=px.colors.diverging.Tealrose_r) + parcoords.data[0].domain.update({'x': [0.05, 0.8], 'y': [0.5, 0.90]}) + + fig.add_trace(parcoords.data[0]) + fig.layout.update({'coloraxis': parcoords.layout.coloraxis}) + fig.update_layout(coloraxis_colorbar=dict( + yanchor="top", + xanchor='right', + y=1, + x=0.95, + len=0.5, + thickness=40, + )) + + fig.layout.margin.update({'l': 20, 'r': 20, 'b': 20, 't': 40}) + fig.update_layout( + title={ + 'text': title, + 'y': 0.98, + 'x': 0.5, + 'xanchor': 'center', + 'yanchor': 'top', + 'font': { + 'size': 28 + }} + ) + + return fig diff --git a/src/strategy/strategy.py b/src/strategy/strategy.py index cd367b9..aa38150 100644 --- a/src/strategy/strategy.py +++ b/src/strategy/strategy.py @@ -1,6 +1,6 @@ import numpy as np import pandas as pd -import logging +# import logging from typing import Dict, Any EXIT_POSITION = 0 @@ -101,6 +101,15 @@ class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase): self.quantile_enter_short = quantile_enter_short self.quantile_exit_short = quantile_exit_short + def info(self): + return super().info() | { + 'quantiles': self.quantiles, + 'quantile_enter_long': self.quantile_enter_long, + 'quantile_exit_long': self.quantile_exit_long, + 'quantile_enter_short': self.quantile_enter_short, + 'quantile_exit_short': self.quantile_exit_short + } + def get_positions(self, data): arr_preds = data['prediction'].to_numpy() arr_target = data[self.target].to_numpy()