Modify training script with option to save predictions on test set
parent 063ea18d00
commit a3c5c76000

scripts/train.py (156 lines changed)
@@ -3,18 +3,22 @@ import logging
 import wandb
 import pprint
 import os
-import pandas as pd
+import tempfile
+import torch
 import lightning.pytorch as pl
 
 from lightning.pytorch.utilities.model_summary import ModelSummary
 from lightning.pytorch.callbacks.early_stopping import EarlyStopping
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
-from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
-from pytorch_forecasting import QuantileLoss
 
-from ml.loss import GMADL
+from ml.loss import get_loss
 from ml.model import get_model
+from ml.data import (
+    get_dataset_from_wandb,
+    get_train_validation_split,
+    build_time_series_dataset
+)
 
 
 def get_args():
@@ -71,77 +75,80 @@ def get_args():
     parser.add_argument('--no-wandb', action='store_true',
                         help='Disables wandb, for testing.')
 
+    parser.add_argument('--store-predictions', action='store_true',
+                        help='Whether to store predictions of the best run.')
+
     return parser.parse_args()
 
 
-def get_dataset(config, project):
-    artifact_name = f"{project}/{config['data']['dataset']}"
-    artifact = wandb.Api().artifact(artifact_name)
-    base_path = artifact.download()
-    logging.info(f"Artifacts downloaded to {base_path}")
+# def get_dataset(config, project):
+#     artifact_name = f"{project}/{config['data']['dataset']}"
+#     artifact = wandb.Api().artifact(artifact_name)
+#     base_path = artifact.download()
+#     logging.info(f"Artifacts downloaded to {base_path}")
 
-    name = artifact.metadata['name']
-    part_name = f"in-sample-{config['data']['sliding_window']}"
-    data = pd.read_csv(os.path.join(
-        base_path, name + '-' + part_name + '.csv'))
-    logging.info(f"Using part: {part_name}")
+#     name = artifact.metadata['name']
+#     part_name = f"in-sample-{config['data']['sliding_window']}"
+#     data = pd.read_csv(os.path.join(
+#         base_path, name + '-' + part_name + '.csv'))
+#     logging.info(f"Using part: {part_name}")
 
-    # TODO: Fix in dataset
-    data['weekday'] = data['weekday'].astype('str')
-    data['hour'] = data['hour'].astype('str')
+#     # TODO: Fix in dataset
+#     data['weekday'] = data['weekday'].astype('str')
+#     data['hour'] = data['hour'].astype('str')
 
-    validation_part = config['data']['validation']
-    logging.info(f"Using {validation_part} of in sample part for validation.")
-    train_data = data.iloc[:int(len(data) * (1 - validation_part))]
-    val_data = data.iloc[len(train_data) - config['past_window']:]
-    logging.info(f"Trainin part size: {len(train_data)}")
-    logging.info(
-        f"Validation part size: {len(val_data)} "
-        + f"({len(data) - len(train_data)} + {config['past_window']})")
+#     validation_part = config['data']['validation']
+#     logging.info(f"Using {validation_part} of in sample part for validation.")
+#     train_data = data.iloc[:int(len(data) * (1 - validation_part))]
+#     val_data = data.iloc[len(train_data) - config['past_window']:]
+#     logging.info(f"Trainin part size: {len(train_data)}")
+#     logging.info(
+#         f"Validation part size: {len(val_data)} "
+#         + f"({len(data) - len(train_data)} + {config['past_window']})")
 
-    logging.info("Building time series dataset for training.")
-    train = TimeSeriesDataSet(
-        train_data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
-        min_encoder_length=config['past_window'],
-        max_encoder_length=config['past_window'],
-        min_prediction_length=config['future_window'],
-        max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields'][
-            'dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields'][
-            'dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields'][
-            'dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields'][
-            'dynamic_unknown_cat'],
-        randomize_length=False
-    )
+#     logging.info("Building time series dataset for training.")
+#     train = TimeSeriesDataSet(
+#         train_data,
+#         time_idx=config['data']['fields']['time_index'],
+#         target=config['data']['fields']['target'],
+#         group_ids=config['data']['fields']['group_ids'],
+#         min_encoder_length=config['past_window'],
+#         max_encoder_length=config['past_window'],
+#         min_prediction_length=config['future_window'],
+#         max_prediction_length=config['future_window'],
+#         static_reals=config['data']['fields']['static_real'],
+#         static_categoricals=config['data']['fields']['static_cat'],
+#         time_varying_known_reals=config['data']['fields'][
+#             'dynamic_known_real'],
+#         time_varying_known_categoricals=config['data']['fields'][
+#             'dynamic_known_cat'],
+#         time_varying_unknown_reals=config['data']['fields'][
+#             'dynamic_unknown_real'],
+#         time_varying_unknown_categoricals=config['data']['fields'][
+#             'dynamic_unknown_cat'],
+#         randomize_length=False
+#     )
 
-    logging.info("Building time series dataset for validation.")
-    val = TimeSeriesDataSet.from_dataset(
-        train, val_data, stop_randomization=True)
+#     logging.info("Building time series dataset for validation.")
+#     val = TimeSeriesDataSet.from_dataset(
+#         train, val_data, stop_randomization=True)
 
-    return train, val
+#     return train, val
 
 
-def get_loss(config):
-    loss_name = config['loss']['name']
+# def get_loss(config):
+#     loss_name = config['loss']['name']
 
-    if loss_name == 'Quantile':
-        return QuantileLoss(config['loss']['quantiles'])
+#     if loss_name == 'Quantile':
+#         return QuantileLoss(config['loss']['quantiles'])
 
-    if loss_name == 'GMADL':
-        return GMADL(
-            a=config['loss']['a'],
-            b=config['loss']['b']
-        )
+#     if loss_name == 'GMADL':
+#         return GMADL(
+#             a=config['loss']['a'],
+#             b=config['loss']['b']
+#         )
 
-    raise ValueError("Unknown loss")
+#     raise ValueError("Unknown loss")
 
 
 def main():
@@ -159,7 +166,10 @@ def main():
     logging.info("Using experiment config:\n%s", pprint.pformat(config))
 
     # Get time series dataset
-    train, valid = get_dataset(config, args.project)
+    in_sample, out_of_sample = get_dataset_from_wandb(run)
+    train_data, valid_data = get_train_validation_split(run.config, in_sample)
+    train = build_time_series_dataset(run.config, train_data)
+    valid = build_time_series_dataset(run.config, valid_data)
     logging.info("Train dataset parameters:\n" +
                  f"{pprint.pformat(train.get_parameters())}")
 
@@ -230,6 +240,28 @@ def main():
             batch_size=batch_size, train=False, num_workers=3),
         ckpt_path=ckpt_path)
 
+    if not args.no_wandb and args.store_predictions:
+        logging.info("Computing and storing predictions of best model.")
+        test = build_time_series_dataset(run.config, out_of_sample)
+        test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
+            test.to_dataloader(train=False, batch_size=batch_size),
+            mode="raw",
+            return_index=True,
+            trainer_kwargs={
+                'logger': False
+            })
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            for key, value in {
+                'index': test_preds.index,
+                'predictions': test_preds.output.prediction
+            }.items():
+                torch.save(value, os.path.join(tempdir, key + ".pt"))
+            pred_artifact = wandb.Artifact(
+                f"prediction-{run.id}", type='prediction')
+            pred_artifact.add_dir(tempdir)
+            run.log_artifact(pred_artifact)
+
     # Clean up models that do not have best/latest tags, to save space on wandb
     for artifact in wandb.Api().run(run.path).logged_artifacts():
         if artifact.type == "model" and not artifact.aliases:
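The `--store-predictions` branch above logs a wandb artifact named `prediction-<run id>` (type `prediction`) holding `index.pt` and `predictions.pt`. A minimal sketch of pulling that artifact back down; `<entity>/<project>` and `<run_id>` are placeholders to substitute with your own values:

import os

import torch
import wandb

# Download the prediction artifact logged by the run above.
# "<entity>/<project>" and "<run_id>" are placeholders, not real values.
api = wandb.Api()
pred_artifact = api.artifact(
    "<entity>/<project>/prediction-<run_id>:latest", type="prediction")
base_path = pred_artifact.download()

# The run stored each value with torch.save under "<key>.pt",
# so load the index and the raw predictions back the same way.
index = torch.load(os.path.join(base_path, "index.pt"))
predictions = torch.load(os.path.join(base_path, "predictions.pt"))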
ml/data.py
@@ -11,18 +11,22 @@ def get_dataset_from_wandb(run, window=None):
     base_path = artifact.download()
 
     name = artifact.metadata['name']
-    in_sample_name = f"in-sample-{window or run.config['data']['sliding_window']}"
+    in_sample_name =\
+        f"in-sample-{window or run.config['data']['sliding_window']}"
     in_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + in_sample_name + '.csv'))
-    out_of_sample_name = f"out-of-sample-{window or run.config['data']['sliding_window']}"
+    out_of_sample_name =\
+        f"out-of-sample-{window or run.config['data']['sliding_window']}"
    out_of_sample_data = pd.read_csv(os.path.join(
         base_path, name + '-' + out_of_sample_name + '.csv'))
 
     return in_sample_data, out_of_sample_data
 
 
 def get_train_validation_split(config, in_sample_data):
     validation_part = config['data']['validation']
-    train_data = in_sample_data.iloc[:int(len(in_sample_data) * (1 - validation_part))]
+    train_data = in_sample_data.iloc[:int(
+        len(in_sample_data) * (1 - validation_part))]
     val_data = in_sample_data.iloc[len(train_data) - config['past_window']:]
 
     return train_data, val_data
@@ -30,24 +34,26 @@ def get_train_validation_split(config, in_sample_data):
 
 def build_time_series_dataset(config, data):
     data = data.copy()
+    # TODO: Fix in dataset
     data['weekday'] = data['weekday'].astype('str')
     data['hour'] = data['hour'].astype('str')
 
     time_series_dataset = TimeSeriesDataSet(
         data,
-        time_idx=config['data']['fields']['time_index'],
-        target=config['data']['fields']['target'],
-        group_ids=config['data']['fields']['group_ids'],
+        time_idx=config['fields']['time_index'],
+        target=config['fields']['target'],
+        group_ids=config['fields']['group_ids'],
         min_encoder_length=config['past_window'],
         max_encoder_length=config['past_window'],
         min_prediction_length=config['future_window'],
         max_prediction_length=config['future_window'],
-        static_reals=config['data']['fields']['static_real'],
-        static_categoricals=config['data']['fields']['static_cat'],
-        time_varying_known_reals=config['data']['fields']['dynamic_known_real'],
-        time_varying_known_categoricals=config['data']['fields']['dynamic_known_cat'],
-        time_varying_unknown_reals=config['data']['fields']['dynamic_unknown_real'],
-        time_varying_unknown_categoricals=config['data']['fields']['dynamic_unknown_cat'],
+        static_reals=config['fields']['static_real'],
+        static_categoricals=config['fields']['static_cat'],
+        time_varying_known_reals=config['fields']['dynamic_known_real'],
+        time_varying_known_categoricals=config['fields']['dynamic_known_cat'],
+        time_varying_unknown_reals=config['fields']['dynamic_unknown_real'],
+        time_varying_unknown_categoricals=config['fields'][
+            'dynamic_unknown_cat'],
         randomize_length=False,
     )
 
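Note that `build_time_series_dataset` now reads its field lists from `config['fields']`, where the old, commented-out `get_dataset` used `config['data']['fields']`. A sketch of the config shape the function expects; the window sizes and column names below are illustrative, not the project's actual schema:

from ml.data import build_time_series_dataset

# Illustrative config: the keys mirror what build_time_series_dataset
# accesses, but every value here is a made-up example.
example_config = {
    'past_window': 96,
    'future_window': 24,
    'fields': {
        'time_index': 'time_idx',
        'target': 'returns',
        'group_ids': ['series_id'],
        'static_real': [],
        'static_cat': [],
        'dynamic_known_real': [],
        'dynamic_known_cat': ['weekday', 'hour'],
        'dynamic_unknown_real': ['returns'],
        'dynamic_unknown_cat': [],
    },
}

# dataset = build_time_series_dataset(example_config, some_dataframe)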
ml/loss.py
@@ -1,8 +1,24 @@
 import torch
 
+from pytorch_forecasting import QuantileLoss
 from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
 
 
+def get_loss(config):
+    loss_name = config['loss']['name']
+
+    if loss_name == 'Quantile':
+        return QuantileLoss(config['loss']['quantiles'])
+
+    if loss_name == 'GMADL':
+        return GMADL(
+            a=config['loss']['a'],
+            b=config['loss']['b']
+        )
+
+    raise ValueError("Unknown loss")
+
+
 class GMADL(MultiHorizonMetric):
     """GMADL loss function."""
 
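The relocated `get_loss` dispatches on `config['loss']['name']` and raises on any other name. A quick usage sketch; the quantile levels and GMADL `a`/`b` values are illustrative, not tuned:

from ml.loss import get_loss

# Quantile loss with example quantile levels.
quantile_loss = get_loss(
    {'loss': {'name': 'Quantile', 'quantiles': [0.1, 0.5, 0.9]}})

# GMADL with placeholder a/b parameters.
gmadl_loss = get_loss({'loss': {'name': 'GMADL', 'a': 100, 'b': 1}})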