From a3c5c760003ebc617268aac8454942b78e0aba4c Mon Sep 17 00:00:00 2001
From: Filip Stefaniuk
Date: Sat, 14 Sep 2024 14:10:29 +0200
Subject: [PATCH] Add option to save test set predictions in the training
 script

---
 scripts/train.py | 156 ++++++++++++++++++++++++++++-------------------
 src/ml/data.py   |  32 ++++++----
 src/ml/loss.py   |  16 +++++
 3 files changed, 129 insertions(+), 75 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index ba3a811..89b116f 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -3,18 +3,22 @@ import logging
 import wandb
 import pprint
 import os
-import pandas as pd
+import tempfile
+import torch
 
 import lightning.pytorch as pl
 from lightning.pytorch.utilities.model_summary import ModelSummary
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
-from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
-from pytorch_forecasting import QuantileLoss
 
-from ml.loss import GMADL
+from ml.loss import get_loss
 from ml.model import get_model
+from ml.data import (
+    get_dataset_from_wandb,
+    get_train_validation_split,
+    build_time_series_dataset
+)
 
 
 def get_args():
@@ -71,77 +75,80 @@ def get_args():
     parser.add_argument('--no-wandb', action='store_true',
                         help='Disables wandb, for testing.')
 
+    parser.add_argument('--store-predictions', action='store_true',
+                        help='Whether to store predictions of the best model.')
+
     return parser.parse_args()
 
 
-def get_dataset(config, project):
-    artifact_name = f"{project}/{config['data']['dataset']}"
-    artifact = wandb.Api().artifact(artifact_name)
-    base_path = artifact.download()
-    logging.info(f"Artifacts downloaded to {base_path}")
+# def get_dataset(config, project):
+#     artifact_name = f"{project}/{config['data']['dataset']}"
+#     artifact = wandb.Api().artifact(artifact_name)
+#     base_path = artifact.download()
+#     logging.info(f"Artifacts downloaded to {base_path}")
 
-    name = artifact.metadata['name']
-    part_name = f"in-sample-{config['data']['sliding_window']}"
-    data = pd.read_csv(os.path.join(
-        base_path, name + '-' + part_name + '.csv'))
-    logging.info(f"Using part: {part_name}")
+#     name = artifact.metadata['name']
+#     part_name = f"in-sample-{config['data']['sliding_window']}"
+#     data = pd.read_csv(os.path.join(
+#         base_path, name + '-' + part_name + '.csv'))
+#     logging.info(f"Using part: {part_name}")
 
-    # TODO: Fix in dataset
-    data['weekday'] = data['weekday'].astype('str')
-    data['hour'] = data['hour'].astype('str')
+#     # TODO: Fix in dataset
+#     data['weekday'] = data['weekday'].astype('str')
+#     data['hour'] = data['hour'].astype('str')
 
-    validation_part = config['data']['validation']
-    logging.info(f"Using {validation_part} of in sample part for validation.")
-    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
-    val_data = data.iloc[len(train_data) - config['past_window']:]
-    logging.info(f"Trainin part size: {len(train_data)}")
-    logging.info(
-        f"Validation part size: {len(val_data)} "
-        + f"({len(data) - len(train_data)} + {config['past_window']})")
+#     validation_part = config['data']['validation']
+#     logging.info(f"Using {validation_part} of in-sample part for validation.")
+#     train_data = data.iloc[:int(len(data) * (1 - validation_part))]
+#     val_data = data.iloc[len(train_data) - config['past_window']:]
+#     logging.info(f"Training part size: {len(train_data)}")
+#     logging.info(
+#         f"Validation part size: {len(val_data)} "
+#         + f"({len(data) - len(train_data)} + {config['past_window']})")
 
-    logging.info("Building time series dataset for training.")
-    train = TimeSeriesDataSet(
-        train_data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
-        min_encoder_length=config['past_window'],
-        max_encoder_length=config['past_window'],
-        min_prediction_length=config['future_window'],
-        max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields'][
-            'dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields'][
-            'dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields'][
-            'dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields'][
-            'dynamic_unknown_cat'],
-        randomize_length=False
-    )
+#     logging.info("Building time series dataset for training.")
+#     train = TimeSeriesDataSet(
+#         train_data,
+#         time_idx=config['data']['fields']['time_index'],
+#         target=config['data']['fields']['target'],
+#         group_ids=config['data']['fields']['group_ids'],
+#         min_encoder_length=config['past_window'],
+#         max_encoder_length=config['past_window'],
+#         min_prediction_length=config['future_window'],
+#         max_prediction_length=config['future_window'],
+#         static_reals=config['data']['fields']['static_real'],
+#         static_categoricals=config['data']['fields']['static_cat'],
+#         time_varying_known_reals=config['data']['fields'][
+#             'dynamic_known_real'],
+#         time_varying_known_categoricals=config['data']['fields'][
+#             'dynamic_known_cat'],
+#         time_varying_unknown_reals=config['data']['fields'][
+#             'dynamic_unknown_real'],
+#         time_varying_unknown_categoricals=config['data']['fields'][
+#             'dynamic_unknown_cat'],
+#         randomize_length=False
+#     )
 
-    logging.info("Building time series dataset for validation.")
-    val = TimeSeriesDataSet.from_dataset(
-        train, val_data, stop_randomization=True)
+#     logging.info("Building time series dataset for validation.")
+#     val = TimeSeriesDataSet.from_dataset(
+#         train, val_data, stop_randomization=True)
 
-    return train, val
+#     return train, val
 
 
-def get_loss(config):
-    loss_name = config['loss']['name']
+# def get_loss(config):
+#     loss_name = config['loss']['name']
 
-    if loss_name == 'Quantile':
-        return QuantileLoss(config['loss']['quantiles'])
+#     if loss_name == 'Quantile':
+#         return QuantileLoss(config['loss']['quantiles'])
 
-    if loss_name == 'GMADL':
-        return GMADL(
-            a=config['loss']['a'],
-            b=config['loss']['b']
-        )
+#     if loss_name == 'GMADL':
+#         return GMADL(
+#             a=config['loss']['a'],
+#             b=config['loss']['b']
+#         )
 
-    raise ValueError("Unknown loss")
+#     raise ValueError("Unknown loss")
 
 
 def main():
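Note on the split logic that the retired get_dataset implemented (it now lives as get_train_validation_split in src/ml/data.py further down): the validation slice deliberately starts past_window rows before the train/validation boundary, so the first validation sample still has a full encoder window. A toy illustration of the indexing, with made-up sizes (1000 rows, validation fraction 0.2, past_window 50 are illustrative values, not project defaults):

    import pandas as pd

    # Stand-in for the in-sample CSV part.
    data = pd.DataFrame({'x': range(1000)})
    validation_part, past_window = 0.2, 50

    # Same slicing as in get_train_validation_split.
    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
    val_data = data.iloc[len(train_data) - past_window:]

    assert len(train_data) == 800
    assert len(val_data) == 250  # 200 held-out rows + 50 encoder-overlap rows

This is exactly the "{held-out} + {past_window}" breakdown that the validation-size log message prints.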
@@ -159,7 +166,10 @@ def main():
     logging.info("Using experiment config:\n%s", pprint.pformat(config))
 
     # Get time series dataset
-    train, valid = get_dataset(config, args.project)
+    in_sample, out_of_sample = get_dataset_from_wandb(run)
+    train_data, valid_data = get_train_validation_split(run.config, in_sample)
+    train = build_time_series_dataset(run.config, train_data)
+    valid = build_time_series_dataset(run.config, valid_data)
 
     logging.info("Train dataset parameters:\n"
                  + f"{pprint.pformat(train.get_parameters())}")
@@ -230,6 +240,28 @@ def main():
                 batch_size=batch_size, train=False, num_workers=3),
             ckpt_path=ckpt_path)
 
+    if not args.no_wandb and args.store_predictions:
+        logging.info("Computing and storing predictions of best model.")
+        test = build_time_series_dataset(run.config, out_of_sample)
+        test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
+            test.to_dataloader(train=False, batch_size=batch_size),
+            mode="raw",
+            return_index=True,
+            trainer_kwargs={
+                'logger': False
+            })
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            for key, value in {
+                'index': test_preds.index,
+                'predictions': test_preds.output.prediction
+            }.items():
+                torch.save(value, os.path.join(tempdir, key + ".pt"))
+            pred_artifact = wandb.Artifact(
+                f"prediction-{run.id}", type='prediction')
+            pred_artifact.add_dir(tempdir)
+            run.log_artifact(pred_artifact)
+
     # Clean up models that do not have best/latest tags, to save space on wandb
     for artifact in wandb.Api().run(run.path).logged_artifacts():
         if artifact.type == "model" and not artifact.aliases:
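The artifact logged above holds two files written with torch.save: index.pt (the prediction index, a pickled pandas object when return_index=True) and predictions.pt (the raw prediction tensor). A downstream script could fetch them back roughly like this; the entity/project path and the :latest alias are placeholders, only the prediction-<run id> name and the file names come from the patch:

    import os

    import torch
    import wandb

    artifact = wandb.Api().artifact(
        "entity/project/prediction-<run-id>:latest")
    base_path = artifact.download()

    # weights_only=False because index.pt stores a pickled non-tensor
    # object; predictions.pt is a plain tensor.
    index = torch.load(
        os.path.join(base_path, "index.pt"), weights_only=False)
    predictions = torch.load(
        os.path.join(base_path, "predictions.pt"), weights_only=False)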
diff --git a/src/ml/data.py b/src/ml/data.py
index 2edabf7..b84cdbc 100644
--- a/src/ml/data.py
+++ b/src/ml/data.py
@@ -11,18 +11,22 @@ def get_dataset_from_wandb(run, window=None):
     base_path = artifact.download()
 
     name = artifact.metadata['name']
-    in_sample_name = f"in-sample-{window or run.config['data']['sliding_window']}"
+    in_sample_name = (
+        f"in-sample-{window or run.config['data']['sliding_window']}")
     in_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + in_sample_name + '.csv'))
 
-    out_of_sample_name = f"out-of-sample-{window or run.config['data']['sliding_window']}"
+    out_of_sample_name = (
+        f"out-of-sample-{window or run.config['data']['sliding_window']}")
     out_of_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + out_of_sample_name + '.csv'))
 
     return in_sample_data, out_of_sample_data
 
+
 def get_train_validation_split(config, in_sample_data):
     validation_part = config['data']['validation']
-    train_data = in_sample_data.iloc[:int(len(in_sample_data) * (1 - validation_part))]
+    train_data = in_sample_data.iloc[:int(
+        len(in_sample_data) * (1 - validation_part))]
     val_data = in_sample_data.iloc[len(train_data) - config['past_window']:]
 
     return train_data, val_data
@@ -30,25 +34,27 @@ def get_train_validation_split(config, in_sample_data):
 def build_time_series_dataset(config, data):
     data = data.copy()
 
+    # TODO: Fix in dataset
     data['weekday'] = data['weekday'].astype('str')
     data['hour'] = data['hour'].astype('str')
 
     time_series_dataset = TimeSeriesDataSet(
         data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
+        time_idx=config['fields']['time_index'],
+        target=config['fields']['target'],
+        group_ids=config['fields']['group_ids'],
         min_encoder_length=config['past_window'],
         max_encoder_length=config['past_window'],
         min_prediction_length=config['future_window'],
         max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
+        static_reals=config['fields']['static_real'],
+        static_categoricals=config['fields']['static_cat'],
+        time_varying_known_reals=config['fields']['dynamic_known_real'],
+        time_varying_known_categoricals=config['fields']['dynamic_known_cat'],
+        time_varying_unknown_reals=config['fields']['dynamic_unknown_real'],
+        time_varying_unknown_categoricals=config['fields'][
+            'dynamic_unknown_cat'],
         randomize_length=False,
     )
 
-    return time_series_dataset
\ No newline at end of file
+    return time_series_dataset
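build_time_series_dataset now reads the field lists from config['fields'] rather than config['data']['fields'], i.e. the wandb run config is assumed to expose fields at the top level. A sketch of the mapping the function expects; the key names come from the code above, every value and column is invented for illustration:

    import pandas as pd
    from ml.data import build_time_series_dataset

    # Toy frame with the columns the config below references.
    out_of_sample_data = pd.DataFrame({
        'time_idx': range(100),
        'group': 0,
        'weekday': 0,
        'hour': 0,
        'returns': 0.0,
    })

    config = {
        'past_window': 50,
        'future_window': 10,
        'fields': {
            'time_index': 'time_idx',
            'target': 'returns',
            'group_ids': ['group'],
            'static_real': [],
            'static_cat': [],
            'dynamic_known_real': [],
            'dynamic_known_cat': ['weekday', 'hour'],
            'dynamic_unknown_real': ['returns'],
            'dynamic_unknown_cat': [],
        },
    }
    test = build_time_series_dataset(config, out_of_sample_data)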
diff --git a/src/ml/loss.py b/src/ml/loss.py
index af168f7..e2f1942 100644
--- a/src/ml/loss.py
+++ b/src/ml/loss.py
@@ -1,8 +1,24 @@
 import torch
 
+from pytorch_forecasting import QuantileLoss
 from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
 
 
+def get_loss(config):
+    loss_name = config['loss']['name']
+
+    if loss_name == 'Quantile':
+        return QuantileLoss(config['loss']['quantiles'])
+
+    if loss_name == 'GMADL':
+        return GMADL(
+            a=config['loss']['a'],
+            b=config['loss']['b']
+        )
+
+    raise ValueError("Unknown loss")
+
+
 class GMADL(MultiHorizonMetric):
     """GMADL loss function."""
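get_loss dispatches on config['loss']['name'] and is the single entry point train.py now imports. A usage sketch with invented parameter values (the real ones come from the experiment config on wandb):

    from ml.loss import get_loss

    quantile_loss = get_loss(
        {'loss': {'name': 'Quantile', 'quantiles': [0.1, 0.5, 0.9]}})
    gmadl_loss = get_loss({'loss': {'name': 'GMADL', 'a': 100, 'b': 2}})

    # Any other name raises ValueError("Unknown loss").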