Modify training script with option to save predictions on test set

Filip Stefaniuk 2024-09-14 14:10:29 +02:00
parent 063ea18d00
commit a3c5c76000
3 changed files with 129 additions and 75 deletions

Training script

@@ -3,18 +3,22 @@ import logging
 import wandb
 import pprint
 import os
-import pandas as pd
+import tempfile
+import torch
 import lightning.pytorch as pl
 from lightning.pytorch.utilities.model_summary import ModelSummary
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
-from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
-from pytorch_forecasting import QuantileLoss
-from ml.loss import GMADL
+from ml.loss import get_loss
 from ml.model import get_model
+from ml.data import (
+    get_dataset_from_wandb,
+    get_train_validation_split,
+    build_time_series_dataset
+)


 def get_args():
@@ -71,77 +75,80 @@ def get_args():
     parser.add_argument('--no-wandb', action='store_true',
                         help='Disables wandb, for testing.')
+    parser.add_argument('--store-predictions', action='store_true',
+                        help='Whether to store predictions of the best run.')
     return parser.parse_args()


-def get_dataset(config, project):
-    artifact_name = f"{project}/{config['data']['dataset']}"
-    artifact = wandb.Api().artifact(artifact_name)
-    base_path = artifact.download()
-    logging.info(f"Artifacts downloaded to {base_path}")
+# def get_dataset(config, project):
+#     artifact_name = f"{project}/{config['data']['dataset']}"
+#     artifact = wandb.Api().artifact(artifact_name)
+#     base_path = artifact.download()
+#     logging.info(f"Artifacts downloaded to {base_path}")

-    name = artifact.metadata['name']
-    part_name = f"in-sample-{config['data']['sliding_window']}"
-    data = pd.read_csv(os.path.join(
-        base_path, name + '-' + part_name + '.csv'))
-    logging.info(f"Using part: {part_name}")
+#     name = artifact.metadata['name']
+#     part_name = f"in-sample-{config['data']['sliding_window']}"
+#     data = pd.read_csv(os.path.join(
+#         base_path, name + '-' + part_name + '.csv'))
+#     logging.info(f"Using part: {part_name}")

-    # TODO: Fix in dataset
-    data['weekday'] = data['weekday'].astype('str')
-    data['hour'] = data['hour'].astype('str')
+#     # TODO: Fix in dataset
+#     data['weekday'] = data['weekday'].astype('str')
+#     data['hour'] = data['hour'].astype('str')

-    validation_part = config['data']['validation']
-    logging.info(f"Using {validation_part} of in sample part for validation.")
-    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
-    val_data = data.iloc[len(train_data) - config['past_window']:]
-    logging.info(f"Trainin part size: {len(train_data)}")
-    logging.info(
-        f"Validation part size: {len(val_data)} "
-        + f"({len(data) - len(train_data)} + {config['past_window']})")
+#     validation_part = config['data']['validation']
+#     logging.info(f"Using {validation_part} of in sample part for validation.")
+#     train_data = data.iloc[:int(len(data) * (1 - validation_part))]
+#     val_data = data.iloc[len(train_data) - config['past_window']:]
+#     logging.info(f"Trainin part size: {len(train_data)}")
+#     logging.info(
+#         f"Validation part size: {len(val_data)} "
+#         + f"({len(data) - len(train_data)} + {config['past_window']})")

-    logging.info("Building time series dataset for training.")
-    train = TimeSeriesDataSet(
-        train_data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
-        min_encoder_length=config['past_window'],
-        max_encoder_length=config['past_window'],
-        min_prediction_length=config['future_window'],
-        max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields'][
-            'dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields'][
-            'dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields'][
-            'dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields'][
-            'dynamic_unknown_cat'],
-        randomize_length=False
-    )
+#     logging.info("Building time series dataset for training.")
+#     train = TimeSeriesDataSet(
+#         train_data,
+#         time_idx=config['data']['fields']['time_index'],
+#         target=config['data']['fields']['target'],
+#         group_ids=config['data']['fields']['group_ids'],
+#         min_encoder_length=config['past_window'],
+#         max_encoder_length=config['past_window'],
+#         min_prediction_length=config['future_window'],
+#         max_prediction_length=config['future_window'],
+#         static_reals=config['data']['fields']['static_real'],
+#         static_categoricals=config['data']['fields']['static_cat'],
+#         time_varying_known_reals=config['data']['fields'][
+#             'dynamic_known_real'],
+#         time_varying_known_categoricals=config['data']['fields'][
+#             'dynamic_known_cat'],
+#         time_varying_unknown_reals=config['data']['fields'][
+#             'dynamic_unknown_real'],
+#         time_varying_unknown_categoricals=config['data']['fields'][
+#             'dynamic_unknown_cat'],
+#         randomize_length=False
+#     )

-    logging.info("Building time series dataset for validation.")
-    val = TimeSeriesDataSet.from_dataset(
-        train, val_data, stop_randomization=True)
+#     logging.info("Building time series dataset for validation.")
+#     val = TimeSeriesDataSet.from_dataset(
+#         train, val_data, stop_randomization=True)

-    return train, val
+#     return train, val


-def get_loss(config):
-    loss_name = config['loss']['name']
-    if loss_name == 'Quantile':
-        return QuantileLoss(config['loss']['quantiles'])
-    if loss_name == 'GMADL':
-        return GMADL(
-            a=config['loss']['a'],
-            b=config['loss']['b']
-        )
-    raise ValueError("Unknown loss")
+# def get_loss(config):
+#     loss_name = config['loss']['name']
+#     if loss_name == 'Quantile':
+#         return QuantileLoss(config['loss']['quantiles'])
+#     if loss_name == 'GMADL':
+#         return GMADL(
+#             a=config['loss']['a'],
+#             b=config['loss']['b']
+#         )
+#     raise ValueError("Unknown loss")


 def main():
@@ -159,7 +166,10 @@ def main():
     logging.info("Using experiment config:\n%s", pprint.pformat(config))

     # Get time series dataset
-    train, valid = get_dataset(config, args.project)
+    in_sample, out_of_sample = get_dataset_from_wandb(run)
+    train_data, valid_data = get_train_validation_split(run.config, in_sample)
+    train = build_time_series_dataset(run.config, train_data)
+    valid = build_time_series_dataset(run.config, valid_data)

     logging.info("Train dataset parameters:\n" +
                  f"{pprint.pformat(train.get_parameters())}")
@@ -230,6 +240,28 @@ def main():
             batch_size=batch_size, train=False, num_workers=3),
         ckpt_path=ckpt_path)

+    if not args.no_wandb and args.store_predictions:
+        logging.info("Computing and storing predictions of best model.")
+        test = build_time_series_dataset(run.config, out_of_sample)
+        test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
+            test.to_dataloader(train=False, batch_size=batch_size),
+            mode="raw",
+            return_index=True,
+            trainer_kwargs={
+                'logger': False
+            })
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            for key, value in {
+                'index': test_preds.index,
+                'predictions': test_preds.output.prediction
+            }.items():
+                torch.save(value, os.path.join(tempdir, key + ".pt"))
+
+            pred_artifact = wandb.Artifact(
+                f"prediction-{run.id}", type='prediction')
+            pred_artifact.add_dir(tempdir)
+            run.log_artifact(pred_artifact)
+
     # Clean up models that do not have best/latest tags, to save space on wandb
     for artifact in wandb.Api().run(run.path).logged_artifacts():
         if artifact.type == "model" and not artifact.aliases:
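
For later analysis, the logged artifact can be pulled back down and deserialized. A minimal sketch, assuming a project name and run id ("my-project" and "abc123" are placeholders, not values from this repo) and the index.pt/predictions.pt file names written above:

import os
import torch
import wandb

# Download the prediction artifact logged by the training run;
# the path and alias are illustrative.
artifact = wandb.Api().artifact(
    "my-project/prediction-abc123:latest", type="prediction")
base_path = artifact.download()

# torch.save/torch.load round-trip arbitrary picklable objects, so the
# decoder index loads the same way as the prediction tensor.
index = torch.load(os.path.join(base_path, "index.pt"))
predictions = torch.load(os.path.join(base_path, "predictions.pt"))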

ml/data.py

@@ -11,18 +11,22 @@ def get_dataset_from_wandb(run, window=None):
     base_path = artifact.download()

     name = artifact.metadata['name']
-    in_sample_name = f"in-sample-{window or run.config['data']['sliding_window']}"
+    in_sample_name =\
+        f"in-sample-{window or run.config['data']['sliding_window']}"
     in_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + in_sample_name + '.csv'))
-    out_of_sample_name = f"out-of-sample-{window or run.config['data']['sliding_window']}"
+    out_of_sample_name =\
+        f"out-of-sample-{window or run.config['data']['sliding_window']}"
     out_of_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + out_of_sample_name + '.csv'))

     return in_sample_data, out_of_sample_data


 def get_train_validation_split(config, in_sample_data):
     validation_part = config['data']['validation']
-    train_data = in_sample_data.iloc[:int(len(in_sample_data) * (1 - validation_part))]
+    train_data = in_sample_data.iloc[:int(
+        len(in_sample_data) * (1 - validation_part))]
     val_data = in_sample_data.iloc[len(train_data) - config['past_window']:]
     return train_data, val_data
@@ -30,25 +34,27 @@ def get_train_validation_split(config, in_sample_data):
 def build_time_series_dataset(config, data):
     data = data.copy()
+    # TODO: Fix in dataset
     data['weekday'] = data['weekday'].astype('str')
     data['hour'] = data['hour'].astype('str')

     time_series_dataset = TimeSeriesDataSet(
         data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
+        time_idx=config['fields']['time_index'],
+        target=config['fields']['target'],
+        group_ids=config['fields']['group_ids'],
         min_encoder_length=config['past_window'],
         max_encoder_length=config['past_window'],
         min_prediction_length=config['future_window'],
         max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
+        static_reals=config['fields']['static_real'],
+        static_categoricals=config['fields']['static_cat'],
+        time_varying_known_reals=config['fields']['dynamic_known_real'],
+        time_varying_known_categoricals=config['fields']['dynamic_known_cat'],
+        time_varying_unknown_reals=config['fields']['dynamic_unknown_real'],
+        time_varying_unknown_categoricals=config['fields'][
+            'dynamic_unknown_cat'],
         randomize_length=False,
     )

     return time_series_dataset
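
Taken together, the training script now composes these three helpers. A minimal sketch of the flow, assuming `run` is an active wandb run whose config carries the 'data', 'fields', 'past_window', and 'future_window' keys used above:

from ml.data import (
    get_dataset_from_wandb,
    get_train_validation_split,
    build_time_series_dataset
)

# Download the in-sample/out-of-sample CSVs from the run's dataset artifact.
in_sample, out_of_sample = get_dataset_from_wandb(run)

# Split the in-sample part; the validation slice starts past_window rows
# early so its first sample has a full encoder history.
train_data, valid_data = get_train_validation_split(run.config, in_sample)

train = build_time_series_dataset(run.config, train_data)
valid = build_time_series_dataset(run.config, valid_data)
test = build_time_series_dataset(run.config, out_of_sample)  # used by --store-predictions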

ml/loss.py

@@ -1,8 +1,24 @@
 import torch
+from pytorch_forecasting import QuantileLoss
 from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric


+def get_loss(config):
+    loss_name = config['loss']['name']
+    if loss_name == 'Quantile':
+        return QuantileLoss(config['loss']['quantiles'])
+    if loss_name == 'GMADL':
+        return GMADL(
+            a=config['loss']['a'],
+            b=config['loss']['b']
+        )
+    raise ValueError("Unknown loss")
+
+
 class GMADL(MultiHorizonMetric):
     """GMADL loss function."""