Modify training script with option to save predictions on test set
parent 063ea18d00
commit a3c5c76000

scripts/train.py | 156
scripts/train.py
@@ -3,18 +3,22 @@ import logging
 import wandb
 import pprint
 import os
-import pandas as pd
+import tempfile
+import torch
 import lightning.pytorch as pl

 from lightning.pytorch.utilities.model_summary import ModelSummary
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
-from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
-from pytorch_forecasting import QuantileLoss

-from ml.loss import GMADL
+from ml.loss import get_loss
 from ml.model import get_model
+from ml.data import (
+    get_dataset_from_wandb,
+    get_train_validation_split,
+    build_time_series_dataset
+)


 def get_args():
@@ -71,77 +75,80 @@ def get_args():
     parser.add_argument('--no-wandb', action='store_true',
                         help='Disables wandb, for testing.')

+    parser.add_argument('--store-predictions', action='store_true',
+                        help='Whether to store predictions of the best run.')
+
     return parser.parse_args()


-def get_dataset(config, project):
-    artifact_name = f"{project}/{config['data']['dataset']}"
-    artifact = wandb.Api().artifact(artifact_name)
-    base_path = artifact.download()
-    logging.info(f"Artifacts downloaded to {base_path}")
+# def get_dataset(config, project):
+#     artifact_name = f"{project}/{config['data']['dataset']}"
+#     artifact = wandb.Api().artifact(artifact_name)
+#     base_path = artifact.download()
+#     logging.info(f"Artifacts downloaded to {base_path}")

-    name = artifact.metadata['name']
-    part_name = f"in-sample-{config['data']['sliding_window']}"
-    data = pd.read_csv(os.path.join(
-        base_path, name + '-' + part_name + '.csv'))
-    logging.info(f"Using part: {part_name}")
+#     name = artifact.metadata['name']
+#     part_name = f"in-sample-{config['data']['sliding_window']}"
+#     data = pd.read_csv(os.path.join(
+#         base_path, name + '-' + part_name + '.csv'))
+#     logging.info(f"Using part: {part_name}")

-    # TODO: Fix in dataset
-    data['weekday'] = data['weekday'].astype('str')
-    data['hour'] = data['hour'].astype('str')
+#     # TODO: Fix in dataset
+#     data['weekday'] = data['weekday'].astype('str')
+#     data['hour'] = data['hour'].astype('str')

-    validation_part = config['data']['validation']
-    logging.info(f"Using {validation_part} of in sample part for validation.")
-    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
-    val_data = data.iloc[len(train_data) - config['past_window']:]
-    logging.info(f"Training part size: {len(train_data)}")
-    logging.info(
-        f"Validation part size: {len(val_data)} "
-        + f"({len(data) - len(train_data)} + {config['past_window']})")
+#     validation_part = config['data']['validation']
+#     logging.info(f"Using {validation_part} of in sample part for validation.")
+#     train_data = data.iloc[:int(len(data) * (1 - validation_part))]
+#     val_data = data.iloc[len(train_data) - config['past_window']:]
+#     logging.info(f"Training part size: {len(train_data)}")
+#     logging.info(
+#         f"Validation part size: {len(val_data)} "
+#         + f"({len(data) - len(train_data)} + {config['past_window']})")

-    logging.info("Building time series dataset for training.")
-    train = TimeSeriesDataSet(
-        train_data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
-        min_encoder_length=config['past_window'],
-        max_encoder_length=config['past_window'],
-        min_prediction_length=config['future_window'],
-        max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields'][
-            'dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields'][
-            'dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields'][
-            'dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields'][
-            'dynamic_unknown_cat'],
-        randomize_length=False
-    )
+#     logging.info("Building time series dataset for training.")
+#     train = TimeSeriesDataSet(
+#         train_data,
+#         time_idx=config['data']['fields']['time_index'],
+#         target=config['data']['fields']['target'],
+#         group_ids=config['data']['fields']['group_ids'],
+#         min_encoder_length=config['past_window'],
+#         max_encoder_length=config['past_window'],
+#         min_prediction_length=config['future_window'],
+#         max_prediction_length=config['future_window'],
+#         static_reals=config['data']['fields']['static_real'],
+#         static_categoricals=config['data']['fields']['static_cat'],
+#         time_varying_known_reals=config['data']['fields'][
+#             'dynamic_known_real'],
+#         time_varying_known_categoricals=config['data']['fields'][
+#             'dynamic_known_cat'],
+#         time_varying_unknown_reals=config['data']['fields'][
+#             'dynamic_unknown_real'],
+#         time_varying_unknown_categoricals=config['data']['fields'][
+#             'dynamic_unknown_cat'],
+#         randomize_length=False
+#     )

-    logging.info("Building time series dataset for validation.")
-    val = TimeSeriesDataSet.from_dataset(
-        train, val_data, stop_randomization=True)
+#     logging.info("Building time series dataset for validation.")
+#     val = TimeSeriesDataSet.from_dataset(
+#         train, val_data, stop_randomization=True)

-    return train, val
+#     return train, val


-def get_loss(config):
-    loss_name = config['loss']['name']
+# def get_loss(config):
+#     loss_name = config['loss']['name']

-    if loss_name == 'Quantile':
-        return QuantileLoss(config['loss']['quantiles'])
+#     if loss_name == 'Quantile':
+#         return QuantileLoss(config['loss']['quantiles'])

-    if loss_name == 'GMADL':
-        return GMADL(
-            a=config['loss']['a'],
-            b=config['loss']['b']
-        )
+#     if loss_name == 'GMADL':
+#         return GMADL(
+#             a=config['loss']['a'],
+#             b=config['loss']['b']
+#         )

-    raise ValueError("Unknown loss")
+#     raise ValueError("Unknown loss")


 def main():
@@ -159,7 +166,10 @@ def main():
     logging.info("Using experiment config:\n%s", pprint.pformat(config))

     # Get time series dataset
-    train, valid = get_dataset(config, args.project)
+    in_sample, out_of_sample = get_dataset_from_wandb(run)
+    train_data, valid_data = get_train_validation_split(run.config, in_sample)
+    train = build_time_series_dataset(run.config, train_data)
+    valid = build_time_series_dataset(run.config, valid_data)
     logging.info("Train dataset parameters:\n" +
                  f"{pprint.pformat(train.get_parameters())}")

@@ -230,6 +240,28 @@ def main():
             batch_size=batch_size, train=False, num_workers=3),
         ckpt_path=ckpt_path)

+    if not args.no_wandb and args.store_predictions:
+        logging.info("Computing and storing predictions of best model.")
+        test = build_time_series_dataset(run.config, out_of_sample)
+        test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
+            test.to_dataloader(train=False, batch_size=batch_size),
+            mode="raw",
+            return_index=True,
+            trainer_kwargs={
+                'logger': False
+            })
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            for key, value in {
+                'index': test_preds.index,
+                'predictions': test_preds.output.prediction
+            }.items():
+                torch.save(value, os.path.join(tempdir, key + ".pt"))
+            pred_artifact = wandb.Artifact(
+                f"prediction-{run.id}", type='prediction')
+            pred_artifact.add_dir(tempdir)
+            run.log_artifact(pred_artifact)
+
     # Clean up models that do not have best/latest tags, to save space on wandb
     for artifact in wandb.Api().run(run.path).logged_artifacts():
         if artifact.type == "model" and not artifact.aliases:
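
The added block logs the test-set predictions as a wandb artifact named prediction-{run.id} (type 'prediction') containing index.pt and predictions.pt. A minimal sketch of how a downstream script could read that artifact back, assuming hypothetical entity and project names and run id:

    import os

    import torch
    import wandb

    # "my-entity/my-project" and run id "abc123" are placeholders.
    artifact = wandb.Api().artifact(
        "my-entity/my-project/prediction-abc123:latest")
    base_path = artifact.download()

    # torch.load restores whatever torch.save stored above:
    # the prediction index and the raw prediction tensor.
    index = torch.load(os.path.join(base_path, "index.pt"))
    predictions = torch.load(os.path.join(base_path, "predictions.pt"))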
ml/data.py
@@ -11,18 +11,22 @@ def get_dataset_from_wandb(run, window=None):
     base_path = artifact.download()

     name = artifact.metadata['name']
-    in_sample_name = f"in-sample-{window or run.config['data']['sliding_window']}"
+    in_sample_name =\
+        f"in-sample-{window or run.config['data']['sliding_window']}"
     in_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + in_sample_name + '.csv'))
-    out_of_sample_name = f"out-of-sample-{window or run.config['data']['sliding_window']}"
+    out_of_sample_name =\
+        f"out-of-sample-{window or run.config['data']['sliding_window']}"
     out_of_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + out_of_sample_name + '.csv'))

     return in_sample_data, out_of_sample_data


 def get_train_validation_split(config, in_sample_data):
     validation_part = config['data']['validation']
-    train_data = in_sample_data.iloc[:int(len(in_sample_data) * (1 - validation_part))]
+    train_data = in_sample_data.iloc[:int(
+        len(in_sample_data) * (1 - validation_part))]
     val_data = in_sample_data.iloc[len(train_data) - config['past_window']:]

     return train_data, val_data
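
Note that get_train_validation_split starts val_data past_window rows before the end of train_data, so the first validation targets still have a full encoder history behind them. A quick worked example of that arithmetic, with hypothetical sizes:

    # Hypothetical sizes, for illustration only.
    n, validation_part, past_window = 1000, 0.2, 64

    train_end = int(n * (1 - validation_part))  # 800 -> rows 0..799 train
    val_start = train_end - past_window         # 736 -> rows 736..999 validation
    # Rows 736..799 overlap with training but only feed the encoder window;
    # validation targets start at row 800.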
@@ -30,25 +34,27 @@ def get_train_validation_split(config, in_sample_data):

 def build_time_series_dataset(config, data):
+    data = data.copy()
     # TODO: Fix in dataset
     data['weekday'] = data['weekday'].astype('str')
     data['hour'] = data['hour'].astype('str')

     time_series_dataset = TimeSeriesDataSet(
         data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
+        time_idx=config['fields']['time_index'],
+        target=config['fields']['target'],
+        group_ids=config['fields']['group_ids'],
         min_encoder_length=config['past_window'],
         max_encoder_length=config['past_window'],
         min_prediction_length=config['future_window'],
         max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
+        static_reals=config['fields']['static_real'],
+        static_categoricals=config['fields']['static_cat'],
+        time_varying_known_reals=config['fields']['dynamic_known_real'],
+        time_varying_known_categoricals=config['fields']['dynamic_known_cat'],
+        time_varying_unknown_reals=config['fields']['dynamic_unknown_real'],
+        time_varying_unknown_categoricals=config['fields'][
+            'dynamic_unknown_cat'],
         randomize_length=False,
     )

-    return time_series_dataset
+    return time_series_dataset
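
After this change, build_time_series_dataset reads the field definitions from config['fields'] rather than config['data']['fields'], matching the run.config that scripts/train.py now passes in. A sketch of the minimal config shape these helpers expect; all values and field names below are hypothetical:

    # Hypothetical values; only the keys read by build_time_series_dataset
    # and the split/download helpers are shown.
    config = {
        'past_window': 64,
        'future_window': 8,
        'data': {'validation': 0.2, 'sliding_window': 1},
        'fields': {
            'time_index': 'time_idx',
            'target': 'returns',
            'group_ids': ['group'],
            'static_real': [],
            'static_cat': [],
            'dynamic_known_real': [],
            'dynamic_known_cat': ['weekday', 'hour'],
            'dynamic_unknown_real': ['returns'],
            'dynamic_unknown_cat': [],
        },
    }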
ml/loss.py
@@ -1,8 +1,24 @@
 import torch

+from pytorch_forecasting import QuantileLoss
 from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric


+def get_loss(config):
+    loss_name = config['loss']['name']
+
+    if loss_name == 'Quantile':
+        return QuantileLoss(config['loss']['quantiles'])
+
+    if loss_name == 'GMADL':
+        return GMADL(
+            a=config['loss']['a'],
+            b=config['loss']['b']
+        )
+
+    raise ValueError("Unknown loss")
+
+
 class GMADL(MultiHorizonMetric):
     """GMADL loss function."""