From 60741b68fc9d0b3609350dfd7c8947bfa4aba67b Mon Sep 17 00:00:00 2001 From: Filip Stefaniuk Date: Mon, 16 Sep 2024 20:28:45 +0200 Subject: [PATCH] store predictions for all data parts --- scripts/train.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 7192d12..15ed1cc 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -82,6 +82,28 @@ def get_args(): return parser.parse_args() +def store_predictions(run, part, model, data, batch_size=64): + predictions = model.predict( + data.to_dataloader(train=False, batch_size=batch_size), + mode="raw", + return_index=True, + trainer_kwargs={ + 'logger': False + }) + + with tempfile.TemporaryDirectory() as tempdir: + for key, value in { + 'index': predictions.index, + 'predictions': predictions.output.prediction + }.items(): + torch.save(value, os.path.join(tempdir, key + ".pt")) + + pred_artifact = wandb.Artifact( + f"prediction-{part}-{run.id}", type='prediction') + pred_artifact.add_dir(tempdir) + run.log_artifact(pred_artifact) + + def main(): args = get_args() logging.basicConfig(level=args.log_level) @@ -172,28 +194,15 @@ def main(): ckpt_path=ckpt_path) if not args.no_wandb and args.store_predictions: - logging.info("Computing and storing predictions of best model.") test_data = pd.concat( [valid_data[-config['past_window']:], out_of_sample]) test = build_time_series_dataset(run.config, test_data) - test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict( - test.to_dataloader(train=False, batch_size=batch_size), - mode="raw", - return_index=True, - trainer_kwargs={ - 'logger': False - }) + model = model.__class__.load_from_checkpoint(ckpt_path) - with tempfile.TemporaryDirectory() as tempdir: - for key, value in { - 'index': test_preds.index, - 'predictions': test_preds.output.prediction - }.items(): - torch.save(value, os.path.join(tempdir, key + ".pt")) - pred_artifact = wandb.Artifact( - f"prediction-{run.id}", type='prediction') - pred_artifact.add_dir(tempdir) - run.log_artifact(pred_artifact) + logging.info("Computing and storing predictions for best model.") + store_predictions(run, 'train', model, train, batch_size=batch_size) + store_predictions(run, 'valid', model, valid, batch_size=batch_size) + store_predictions(run, 'test', model, test, batch_size=batch_size) # Clean up models that do not have best/latest tags, to save space on wandb for artifact in wandb.Api().run(run.path).logged_artifacts():