refactor notebooks and add evaluation for quantile models
This commit is contained in:
parent
cb3c3588f2
commit
9109be5776
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
||||
167213
notebooks/evaluate.ipynb
167213
notebooks/evaluate.ipynb
File diff suppressed because one or more lines are too long
54
src/ml/data.py
Normal file
54
src/ml/data.py
Normal file
@ -0,0 +1,54 @@
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
import wandb
|
||||
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
|
||||
|
||||
|
||||
def get_dataset_from_wandb(run, window=None):
    """Download the dataset artifact of a W&B run and load both CSV splits.

    The artifact is addressed as ``<run.project>/<run.config['data']['dataset']>``.
    After downloading it, the file prefix is taken from the artifact metadata
    key ``name`` and the two files
    ``<name>-in-sample-<window>.csv`` / ``<name>-out-of-sample-<window>.csv``
    are read from the download directory.

    Args:
        run: Object exposing ``project`` and a ``config`` mapping.
        window: Sliding-window identifier used in the file names. When it is
            falsy, ``run.config['data']['sliding_window']`` is used instead.

    Returns:
        Tuple ``(in_sample_data, out_of_sample_data)`` of DataFrames.
    """
    artifact = wandb.Api().artifact(
        f"{run.project}/{run.config['data']['dataset']}")
    download_dir = artifact.download()
    prefix = artifact.metadata['name']

    def _read_split(split):
        # The window is resolved per split, exactly like the file names
        # themselves: explicit argument first, run configuration otherwise.
        sliding_window = window or run.config['data']['sliding_window']
        file_name = prefix + '-' + f"{split}-{sliding_window}" + '.csv'
        return pd.read_csv(os.path.join(download_dir, file_name))

    # Tuple elements are evaluated left to right: in-sample is read first.
    return _read_split("in-sample"), _read_split("out-of-sample")
|
||||
|
||||
def get_train_validation_split(config, in_sample_data):
    """Split in-sample data into a leading train part and a trailing validation part.

    The first ``1 - config['data']['validation']`` fraction of the rows
    (truncated to an integer count) becomes the training data. The validation
    data starts ``config['past_window']`` rows before the end of the training
    data and runs to the end of ``in_sample_data``, so both parts overlap by
    up to ``past_window`` rows (presumably so the first validation sample has
    a full encoder history -- confirm against the caller).

    Args:
        config: Mapping with ``config['data']['validation']`` (fraction of
            rows reserved for validation) and ``config['past_window']``.
        in_sample_data: DataFrame to split; rows are taken in their existing
            order.

    Returns:
        Tuple ``(train_data, val_data)``.
    """
    validation_part = config['data']['validation']
    train_size = int(len(in_sample_data) * (1 - validation_part))
    train_data = in_sample_data.iloc[:train_size]

    # Clamp at 0: if past_window exceeds the number of training rows, a
    # negative start would make iloc slice from the END of the frame and
    # silently drop most of the validation rows.
    val_start = max(0, len(train_data) - config['past_window'])
    val_data = in_sample_data.iloc[val_start:]

    return train_data, val_data
|
||||
|
||||
|
||||
def build_time_series_dataset(config, data):
    """Build a ``TimeSeriesDataSet`` with fixed encoder and prediction lengths.

    Works on a copy of ``data`` in which the ``weekday`` and ``hour`` columns
    are cast to strings. Encoder length is pinned to ``config['past_window']``
    and prediction length to ``config['future_window']`` (min and max are
    equal), and length randomisation is switched off.

    Args:
        config: Mapping with ``past_window``, ``future_window`` and the column
            lists under ``config['data']['fields']``.
        data: DataFrame containing at least the ``weekday`` and ``hour``
            columns plus every column named in the field configuration.

    Returns:
        The constructed ``TimeSeriesDataSet``.
    """
    frame = data.copy()
    for column in ('weekday', 'hour'):
        frame[column] = frame[column].astype('str')

    fields = config['data']['fields']
    dataset_kwargs = {
        'time_idx': fields['time_index'],
        'target': fields['target'],
        'group_ids': fields['group_ids'],
        'min_encoder_length': config['past_window'],
        'max_encoder_length': config['past_window'],
        'min_prediction_length': config['future_window'],
        'max_prediction_length': config['future_window'],
        'static_reals': fields['static_real'],
        'static_categoricals': fields['static_cat'],
        'time_varying_known_reals': fields['dynamic_known_real'],
        'time_varying_known_categoricals': fields['dynamic_known_cat'],
        'time_varying_unknown_reals': fields['dynamic_unknown_real'],
        'time_varying_unknown_categoricals': fields['dynamic_unknown_cat'],
        'randomize_length': False,
    }

    return TimeSeriesDataSet(frame, **dataset_kwargs)
|
||||
13
src/ml/model.py
Normal file
13
src/ml/model.py
Normal file
@ -0,0 +1,13 @@
|
||||
import wandb
|
||||
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
|
||||
|
||||
# TODO: Maybe save all models on cpu
|
||||
def load_model_from_wandb(run):
    """Load the ``best`` model checkpoint logged for a W&B run.

    The checkpoint artifact is addressed as
    ``<run.project>/model-<run.id>:best``.

    Args:
        run: Object exposing ``project``, ``id`` and a ``config`` mapping with
            ``config['model']['name']``.

    Returns:
        The model restored via ``load_from_checkpoint``.

    Raises:
        ValueError: If ``run.config['model']['name']`` is not a supported
            model name. Raised before any W&B API call is made.
    """
    model_name = run.config['model']['name']

    # Validate first so an unsupported name fails fast, without a network
    # round trip and without being masked by an unrelated W&B API error.
    if model_name != 'TemporalFusionTransformer':
        raise ValueError(f"Invalid model name: {model_name!r}")

    model_path = f"{run.project}/model-{run.id}:best"
    model_artifact = wandb.Api().artifact(model_path)

    return TemporalFusionTransformer.load_from_checkpoint(model_artifact.file())
|
||||
@ -34,7 +34,36 @@ class BuyAndHoldStrategy(StrategyBase):
|
||||
dtype=np.int32)
|
||||
|
||||
|
||||
class ModelReturnsPredictionStrategy(StrategyBase):
|
||||
class ModelPredictionsStrategyBase(StrategyBase):
    """Shared base for strategies whose positions come from model predictions.

    ``run`` left-joins the stored predictions onto the incoming data and
    delegates the actual decision to ``get_positions``, which subclasses
    must implement.
    """

    def __init__(self,
                 predictions,
                 name: str = None):
        """Store the predictions frame and an optional display name.

        ``predictions`` must expose ``time_index``, ``group_id`` and
        ``prediction`` in its ``columns``; each is checked with an assert.
        """
        self.predictions = predictions
        for required_column in ('time_index', 'group_id', 'prediction'):
            assert required_column in self.predictions.columns

        self.name = name

    def info(self):
        """Return the strategy name, or ``'Unknown model'`` when none is set."""
        strategy_name = self.name if self.name else 'Unknown model'
        return {'strategy_name': strategy_name}

    def run(self, data):
        """Join predictions onto ``data`` and return ``get_positions`` of it."""
        join_keys = ['time_index', 'group_id']
        # Left join keeps every row of data; rows without a matching
        # prediction end up with NaN in the prediction column.
        merged_data = pd.merge(data, self.predictions, on=join_keys, how='left')
        return self.get_positions(merged_data)

    def get_positions(self, data):
        """Turn the merged data into positions; subclasses must override."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class ReturnsPredictionStrategy(ModelPredictionsStrategyBase):
|
||||
"""Strategy that selects position based on returns predictions."""
|
||||
|
||||
def __init__(
|
||||
@ -42,22 +71,11 @@ class ModelReturnsPredictionStrategy(StrategyBase):
|
||||
predictions,
|
||||
threshold=0.001,
|
||||
name=None):
|
||||
self.predictions = predictions
|
||||
assert 'time_index' in self.predictions.columns
|
||||
assert 'group_id' in self.predictions.columns
|
||||
assert 'prediction' in self.predictions.columns
|
||||
|
||||
self.name = name or "ML Returns prediction"
|
||||
super().__init__(predictions, name=name)
|
||||
self.threshold = threshold
|
||||
|
||||
def info(self) -> Dict[str, Any]:
|
||||
return {'strategy_name': self.name}
|
||||
|
||||
def run(self, data):
|
||||
arr = pd.merge(
|
||||
data, self.predictions, on=['time_index', 'group_id'],
|
||||
how='left')['prediction'].to_numpy()
|
||||
|
||||
def get_positions(self, data):
|
||||
arr = data['prediction']
|
||||
positions = []
|
||||
for i in range(len(arr)):
|
||||
if arr[i] > self.threshold:
|
||||
@ -70,3 +88,39 @@ class ModelReturnsPredictionStrategy(StrategyBase):
|
||||
positions.append(positions[-1])
|
||||
|
||||
return np.array(positions, dtype=np.int32)
|
||||
|
||||
|
||||
class PriceQuantilePredictionStrategy(ModelPredictionsStrategyBase):
    """Strategy that selects positions from predicted price quantiles.

    For every row, the first and last entries of the ``prediction`` value are
    compared with ``close_price`` (presumably the lowest and highest predicted
    quantile -- confirm against the code that produces the predictions).

    ``info()`` is inherited from the base class, so an unnamed strategy
    reports the base class fallback name instead of ``None``.
    """

    def __init__(
            self,
            predictions,
            name=None,
            threshold=0.001):
        """Store predictions, an optional name and the relative threshold.

        Args:
            predictions: Frame passed to the base class; it must provide the
                columns the base class checks for.
            name: Optional display name of the strategy.
            threshold: Minimum relative distance between a predicted bound
                and the close price required to change the position.
                Defaults to 0.001, the value that used to be hard-coded.
        """
        super().__init__(predictions, name=name)
        self.threshold = threshold

    def get_positions(self, data):
        """Return one position per row of ``data`` as an ``int32`` array.

        A row switches to ``LONG_POSITION`` when its first predicted value
        lies more than ``threshold`` (relative) above the close price, and to
        ``EXIT_POSITION`` when its last predicted value lies more than
        ``threshold`` below it. Rows with a NaN prediction, or without a
        signal, keep the previous position; the sequence starts from
        ``EXIT_POSITION``.
        """
        arr_preds = data['prediction'].to_numpy()
        arr_close_price = data['close_price'].to_numpy()

        positions = []
        current = EXIT_POSITION
        for pred, price in zip(arr_preds, arr_close_price):
            # Rows without a usable prediction carry the position forward.
            if not np.isnan(pred).any():
                if (pred[0] - price) / price > self.threshold:
                    current = LONG_POSITION
                elif (pred[-1] - price) / price < -self.threshold:
                    current = EXIT_POSITION
            positions.append(current)

        return np.array(positions, dtype=np.int32)
|
||||
Loading…
x
Reference in New Issue
Block a user