Modify training script with option to save predictions on test set

Filip Stefaniuk 2024-09-14 14:10:29 +02:00
parent 063ea18d00
commit a3c5c76000
3 changed files with 129 additions and 75 deletions

View File

@@ -3,18 +3,22 @@ import logging
import wandb
import pprint
import os
import pandas as pd
import tempfile
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting import QuantileLoss
from ml.loss import GMADL
from ml.loss import get_loss
from ml.model import get_model
from ml.data import (
get_dataset_from_wandb,
get_train_validation_split,
build_time_series_dataset
)
def get_args():
@@ -71,77 +75,80 @@ def get_args():
parser.add_argument('--no-wandb', action='store_true',
help='Disables wandb, for testing.')
parser.add_argument('--store-predictions', action='store_true',
help='Whether to store predictions of the best run.')
return parser.parse_args()
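With the new flag, a launch could look like the line below (illustrative only: the script name and any other required arguments are placeholders; only --store-predictions and --no-wandb appear in this diff):

    python train.py --store-predictions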
def get_dataset(config, project):
artifact_name = f"{project}/{config['data']['dataset']}"
artifact = wandb.Api().artifact(artifact_name)
base_path = artifact.download()
logging.info(f"Artifacts downloaded to {base_path}")
# def get_dataset(config, project):
# artifact_name = f"{project}/{config['data']['dataset']}"
# artifact = wandb.Api().artifact(artifact_name)
# base_path = artifact.download()
# logging.info(f"Artifacts downloaded to {base_path}")
name = artifact.metadata['name']
part_name = f"in-sample-{config['data']['sliding_window']}"
data = pd.read_csv(os.path.join(
base_path, name + '-' + part_name + '.csv'))
logging.info(f"Using part: {part_name}")
# name = artifact.metadata['name']
# part_name = f"in-sample-{config['data']['sliding_window']}"
# data = pd.read_csv(os.path.join(
# base_path, name + '-' + part_name + '.csv'))
# logging.info(f"Using part: {part_name}")
# TODO: Fix in dataset
data['weekday'] = data['weekday'].astype('str')
data['hour'] = data['hour'].astype('str')
# # TODO: Fix in dataset
# data['weekday'] = data['weekday'].astype('str')
# data['hour'] = data['hour'].astype('str')
validation_part = config['data']['validation']
logging.info(f"Using {validation_part} of in sample part for validation.")
train_data = data.iloc[:int(len(data) * (1 - validation_part))]
val_data = data.iloc[len(train_data) - config['past_window']:]
logging.info(f"Trainin part size: {len(train_data)}")
logging.info(
f"Validation part size: {len(val_data)} "
+ f"({len(data) - len(train_data)} + {config['past_window']})")
# validation_part = config['data']['validation']
# logging.info(f"Using {validation_part} of in sample part for validation.")
# train_data = data.iloc[:int(len(data) * (1 - validation_part))]
# val_data = data.iloc[len(train_data) - config['past_window']:]
# logging.info(f"Trainin part size: {len(train_data)}")
# logging.info(
# f"Validation part size: {len(val_data)} "
# + f"({len(data) - len(train_data)} + {config['past_window']})")
logging.info("Building time series dataset for training.")
train = TimeSeriesDataSet(
train_data,
time_idx=config['data']['fields']['time_index'],
target=config['data']['fields']['target'],
group_ids=config['data']['fields']['group_ids'],
min_encoder_length=config['past_window'],
max_encoder_length=config['past_window'],
min_prediction_length=config['future_window'],
max_prediction_length=config['future_window'],
static_reals=config['data']['fields']['static_real'],
static_categoricals=config['data']['fields']['static_cat'],
time_varying_known_reals=config['data']['fields'][
'dynamic_known_real'],
time_varying_known_categoricals=config['data']['fields'][
'dynamic_known_cat'],
time_varying_unknown_reals=config['data']['fields'][
'dynamic_unknown_real'],
time_varying_unknown_categoricals=config['data']['fields'][
'dynamic_unknown_cat'],
randomize_length=False
)
# logging.info("Building time series dataset for training.")
# train = TimeSeriesDataSet(
# train_data,
# time_idx=config['data']['fields']['time_index'],
# target=config['data']['fields']['target'],
# group_ids=config['data']['fields']['group_ids'],
# min_encoder_length=config['past_window'],
# max_encoder_length=config['past_window'],
# min_prediction_length=config['future_window'],
# max_prediction_length=config['future_window'],
# static_reals=config['data']['fields']['static_real'],
# static_categoricals=config['data']['fields']['static_cat'],
# time_varying_known_reals=config['data']['fields'][
# 'dynamic_known_real'],
# time_varying_known_categoricals=config['data']['fields'][
# 'dynamic_known_cat'],
# time_varying_unknown_reals=config['data']['fields'][
# 'dynamic_unknown_real'],
# time_varying_unknown_categoricals=config['data']['fields'][
# 'dynamic_unknown_cat'],
# randomize_length=False
# )
logging.info("Building time series dataset for validation.")
val = TimeSeriesDataSet.from_dataset(
train, val_data, stop_randomization=True)
# logging.info("Building time series dataset for validation.")
# val = TimeSeriesDataSet.from_dataset(
# train, val_data, stop_randomization=True)
return train, val
# return train, val
def get_loss(config):
loss_name = config['loss']['name']
# def get_loss(config):
# loss_name = config['loss']['name']
if loss_name == 'Quantile':
return QuantileLoss(config['loss']['quantiles'])
# if loss_name == 'Quantile':
# return QuantileLoss(config['loss']['quantiles'])
if loss_name == 'GMADL':
return GMADL(
a=config['loss']['a'],
b=config['loss']['b']
)
# if loss_name == 'GMADL':
# return GMADL(
# a=config['loss']['a'],
# b=config['loss']['b']
# )
raise ValueError("Unknown loss")
# raise ValueError("Unknown loss")
def main():
@@ -159,7 +166,10 @@ def main():
logging.info("Using experiment config:\n%s", pprint.pformat(config))
# Get time series dataset
train, valid = get_dataset(config, args.project)
in_sample, out_of_sample = get_dataset_from_wandb(run)
train_data, valid_data = get_train_validation_split(run.config, in_sample)
train = build_time_series_dataset(run.config, train_data)
valid = build_time_series_dataset(run.config, valid_data)
logging.info("Train dataset parameters:\n" +
f"{pprint.pformat(train.get_parameters())}")
@@ -230,6 +240,28 @@ def main():
batch_size=batch_size, train=False, num_workers=3),
ckpt_path=ckpt_path)
if not args.no_wandb and args.store_predictions:
logging.info("Computing and storing predictions of best model.")
test = build_time_series_dataset(run.config, out_of_sample)
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
test.to_dataloader(train=False, batch_size=batch_size),
mode="raw",
return_index=True,
trainer_kwargs={
'logger': False
})
with tempfile.TemporaryDirectory() as tempdir:
for key, value in {
'index': test_preds.index,
'predictions': test_preds.output.prediction
}.items():
torch.save(value, os.path.join(tempdir, key + ".pt"))
pred_artifact = wandb.Artifact(
f"prediction-{run.id}", type='prediction')
pred_artifact.add_dir(tempdir)
run.log_artifact(pred_artifact)
# Clean up models that do not have best/latest tags, to save space on wandb
for artifact in wandb.Api().run(run.path).logged_artifacts():
if artifact.type == "model" and not artifact.aliases:
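
For context, a minimal sketch of how the stored predictions could be read back later (not part of this commit; the entity/project path is a placeholder, while the artifact name prediction-<run id> and the files index.pt / predictions.pt follow from the code above):

    import os
    import torch
    import wandb

    # Placeholder path; artifact name and type match what the script logs.
    artifact = wandb.Api().artifact("<entity>/<project>/prediction-<run-id>:latest")
    base_path = artifact.download()
    # Saved with torch.save above; weights_only=False because the index is a
    # pandas DataFrame rather than a plain tensor (needed on newer torch versions).
    index = torch.load(os.path.join(base_path, "index.pt"), weights_only=False)
    predictions = torch.load(os.path.join(base_path, "predictions.pt"), weights_only=False)
    print(predictions.shape, len(index))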

View File

@@ -11,18 +11,22 @@ def get_dataset_from_wandb(run, window=None):
base_path = artifact.download()
name = artifact.metadata['name']
in_sample_name = f"in-sample-{window or run.config['data']['sliding_window']}"
in_sample_name =\
f"in-sample-{window or run.config['data']['sliding_window']}"
in_sample_data = pd.read_csv(os.path.join(
base_path, name + '-' + in_sample_name + '.csv'))
out_of_sample_name = f"out-of-sample-{window or run.config['data']['sliding_window']}"
out_of_sample_name =\
f"out-of-sample-{window or run.config['data']['sliding_window']}"
out_of_sample_data = pd.read_csv(os.path.join(
base_path, name + '-' + out_of_sample_name + '.csv'))
return in_sample_data, out_of_sample_data
def get_train_validation_split(config, in_sample_data):
validation_part = config['data']['validation']
train_data = in_sample_data.iloc[:int(len(in_sample_data) * (1 - validation_part))]
train_data = in_sample_data.iloc[:int(
len(in_sample_data) * (1 - validation_part))]
val_data = in_sample_data.iloc[len(train_data) - config['past_window']:]
return train_data, val_data
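
As a worked example of this split (illustrative numbers only): with 10,000 in-sample rows, validation = 0.2 and past_window = 96, train_data covers rows 0-7999 and val_data covers rows 7904-9999, i.e. the validation frame is prefixed with the last past_window training rows so the encoder has full history for the first validation sample.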
@@ -30,25 +34,27 @@ def get_train_validation_split(config, in_sample_data):
def build_time_series_dataset(config, data):
data = data.copy()
# TODO: Fix in dataset
data['weekday'] = data['weekday'].astype('str')
data['hour'] = data['hour'].astype('str')
time_series_dataset = TimeSeriesDataSet(
data,
time_idx=config['data']['fields']['time_index'],
target=config['data']['fields']['target'],
group_ids=config['data']['fields']['group_ids'],
time_idx=config['fields']['time_index'],
target=config['fields']['target'],
group_ids=config['fields']['group_ids'],
min_encoder_length=config['past_window'],
max_encoder_length=config['past_window'],
min_prediction_length=config['future_window'],
max_prediction_length=config['future_window'],
static_reals=config['data']['fields']['static_real'],
static_categoricals=config['data']['fields']['static_cat'],
time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
static_reals=config['fields']['static_real'],
static_categoricals=config['fields']['static_cat'],
time_varying_known_reals=config['fields']['dynamic_known_real'],
time_varying_known_categoricals=config['fields']['dynamic_known_cat'],
time_varying_unknown_reals=config['fields']['dynamic_unknown_real'],
time_varying_unknown_categoricals=config['fields'][
'dynamic_unknown_cat'],
randomize_length=False,
)
return time_series_dataset
return time_series_dataset
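
Note that build_time_series_dataset now reads the field names from config['fields'] rather than config['data']['fields']. A sketch of the nested run config this implies (keys taken from the diff; all example values are placeholders):

    config = {
        "past_window": 96,       # encoder length (placeholder value)
        "future_window": 1,      # prediction length (placeholder value)
        "fields": {
            "time_index": "time_idx",
            "target": "returns",
            "group_ids": ["group"],
            "static_real": [],
            "static_cat": [],
            "dynamic_known_real": [],
            "dynamic_known_cat": ["weekday", "hour"],
            "dynamic_unknown_real": ["returns"],
            "dynamic_unknown_cat": [],
        },
        "data": {
            "dataset": "<dataset-artifact>",  # still nested under 'data'
            "sliding_window": 1,
            "validation": 0.2,
        },
        "loss": {"name": "Quantile", "quantiles": [0.1, 0.5, 0.9]},
    }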

View File

@@ -1,8 +1,24 @@
import torch
from pytorch_forecasting import QuantileLoss
from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
def get_loss(config):
loss_name = config['loss']['name']
if loss_name == 'Quantile':
return QuantileLoss(config['loss']['quantiles'])
if loss_name == 'GMADL':
return GMADL(
a=config['loss']['a'],
b=config['loss']['b']
)
raise ValueError("Unknown loss")
class GMADL(MultiHorizonMetric):
"""GMADL loss function."""