diff --git a/src/strategy/strategy.py b/src/strategy/strategy.py index aa38150..b40a827 100644 --- a/src/strategy/strategy.py +++ b/src/strategy/strategy.py @@ -163,3 +163,29 @@ class ModelQuantilePredictionsStrategy(ModelPredictionsStrategyBase): def get_quantile_idx(self, quantile): return self.quantiles.index(quantile) + + +class ConcatenatedStrategies(StrategyBase): + """ + Evaluates multiple strategies, + each on the next `window_size` data points. + """ + + def __init__(self, window_size, strategies, name='Concatenated Strategy'): + self.window_size = window_size + self.strategies = strategies + self.name = name + + def info(self): + return {'strategy_name': self.name} + + def run(self, data): + chunks = [data[i:i+self.window_size].copy() + for i in range(0, data.shape[0], self.window_size)] + assert len(chunks) <= len(self.strategies) + + positions = [] + for chunk, strategy in zip(chunks, self.strategies): + positions.append(strategy.run(chunk)) + + return np.concatenate(positions)