import argparse
import logging
import os
import pprint
import tempfile

import torch
import wandb
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.model_summary import ModelSummary

from ml.data import (
    get_dataset_from_wandb,
    get_train_validation_split,
    build_time_series_dataset,
)
from ml.loss import get_loss
from ml.model import get_model


def get_args():
    """Parse command line arguments for the training script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "config", help="Experiment configuration file in yaml format.")
    parser.add_argument(
        "-p",
        "--project",
        default="wne-masters-thesis-testing",
        help="W&B project name.")
    parser.add_argument(
        "-l",
        "--log-level",
        default=logging.INFO,
        type=int,
        help="Sets the log level.")
    parser.add_argument(
        "-s",
        "--seed",
        default=42,
        type=int,
        help="Random seed for the training.")
    parser.add_argument(
        "-n",
        "--log-interval",
        default=100,
        type=int,
        help="Log every n steps.")
    parser.add_argument(
        "-v",
        "--val-check-interval",
        default=300,
        type=int,
        help="Run validation every n batches.")
    parser.add_argument(
        "-t",
        "--patience",
        default=5,
        type=int,
        help="Patience for early stopping.")
    parser.add_argument(
        "--no-wandb", action="store_true",
        help="Disables wandb, for testing.")
    parser.add_argument(
        "--store-predictions", action="store_true",
        help="Whether to store predictions of the best run.")
    return parser.parse_args()
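
# Example invocation (a sketch; "train.py" and the config path are
# illustrative names, not part of the repository):
#
#   python train.py configs/experiment.yaml \
#       --project wne-masters-thesis-testing \
#       --val-check-interval 300 --patience 5 --store-predictions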
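
# Sketch of the YAML experiment config this script expects, inferred from the
# keys read here and in the ml.data / ml.loss helpers; the names and values
# below are illustrative assumptions, not a definitive schema:
#
#   batch_size: 128
#   max_epochs: 10
#   past_window: 96            # encoder length
#   future_window: 24          # prediction length
#   model:
#     name: tft
#   loss:
#     name: Quantile           # or GMADL, parameterized by a and b
#     quantiles: [0.1, 0.5, 0.9]
#   data:
#     dataset: <wandb-dataset-artifact>
#     validation: 0.2          # fraction of the in-sample part held out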


def main():
    args = get_args()
    logging.basicConfig(level=args.log_level)
    pl.seed_everything(args.seed)
    run = wandb.init(
        project=args.project,
        config=args.config,
        job_type="train",
        mode="disabled" if args.no_wandb else "online",
    )
    config = run.config
    logging.info("Using experiment config:\n%s", pprint.pformat(config))

    # Get time series dataset
    in_sample, out_of_sample = get_dataset_from_wandb(run)
    train_data, valid_data = get_train_validation_split(run.config, in_sample)
    train = build_time_series_dataset(run.config, train_data)
    valid = build_time_series_dataset(run.config, valid_data)
    logging.info("Train dataset parameters:\n%s",
                 pprint.pformat(train.get_parameters()))

    # Get loss
    loss = get_loss(config)
    logging.info("Using loss %s", loss)

    # Get model
    model = get_model(config, train, loss)
    logging.info("Using model %s", config['model']['name'])
    logging.info("%s", ModelSummary(model))
    logging.info("Model hyperparameters:\n%s", pprint.pformat(model.hparams))

    # Checkpoint callback keeps the three best models by validation loss
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        save_top_k=3,
        mode='min',
    )
    # Logger for W&B
    wandb_logger = WandbLogger(
        project=args.project,
        experiment=run,
        log_model="all") if not args.no_wandb else None
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=args.patience)

    batch_size = config['batch_size']
    logging.info("Training batch size %d.", batch_size)
    epochs = config['max_epochs']
    logging.info("Training for %d epochs.", epochs)
    trainer = pl.Trainer(
        accelerator="auto",
        max_epochs=epochs,
        logger=wandb_logger,
        callbacks=[
            checkpoint_callback,
            early_stopping,
        ],
        log_every_n_steps=args.log_interval,
        val_check_interval=args.val_check_interval,
    )
    if epochs > 0:
        logging.info("Starting training:")
        trainer.fit(
            model,
            train_dataloaders=train.to_dataloader(
                batch_size=batch_size,
                num_workers=3,
            ),
            val_dataloaders=valid.to_dataloader(
                batch_size=batch_size, train=False, num_workers=3,
            ))

    # Run validation with the best checkpoint to log the minimal val_loss.
    # best_model_path is an empty string when nothing was trained or saved,
    # so fall back to None and validate the current weights instead.
    # TODO: Maybe use a different metric like min_val_loss?
    ckpt_path = trainer.checkpoint_callback.best_model_path or None
    trainer.validate(
        model,
        dataloaders=valid.to_dataloader(
            batch_size=batch_size, train=False, num_workers=3),
        ckpt_path=ckpt_path)
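
    # What gets stored below (an assumption about pytorch-forecasting's
    # Prediction object): predict(mode="raw", return_index=True) exposes
    # .index (a DataFrame mapping rows to time/group ids) and
    # .output.prediction (the raw forecast tensor); both are serialized with
    # torch.save and logged as a W&B artifact.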
    if not args.no_wandb and args.store_predictions:
        logging.info("Computing and storing predictions of best model.")
        test = build_time_series_dataset(run.config, out_of_sample)
        # Assumes training produced at least one checkpoint; ckpt_path is
        # None otherwise and load_from_checkpoint would fail.
        test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
            test.to_dataloader(train=False, batch_size=batch_size),
            mode="raw",
            return_index=True,
            trainer_kwargs={
                'logger': False,
            })
        with tempfile.TemporaryDirectory() as tempdir:
            for key, value in {
                'index': test_preds.index,
                'predictions': test_preds.output.prediction,
            }.items():
                torch.save(value, os.path.join(tempdir, key + ".pt"))
            pred_artifact = wandb.Artifact(
                f"prediction-{run.id}", type='prediction')
            pred_artifact.add_dir(tempdir)
            run.log_artifact(pred_artifact)

    # Clean up checkpoints that did not receive the best/latest aliases from
    # WandbLogger(log_model="all"), to save storage space on W&B.
    for artifact in wandb.Api().run(run.path).logged_artifacts():
        if artifact.type == "model" and not artifact.aliases:
            logging.info("Deleting artifact %s", artifact.name)
            artifact.delete()


if __name__ == '__main__':
    main()