From 134860d09e0db279d06c84933e2cf61c2437b208 Mon Sep 17 00:00:00 2001 From: Filip Stefaniuk Date: Sat, 14 Sep 2024 20:29:53 +0200 Subject: [PATCH] fix training script predictions --- scripts/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index d09f065..7192d12 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -6,6 +6,7 @@ import os import tempfile import torch import lightning.pytorch as pl +import pandas as pd from lightning.pytorch.utilities.model_summary import ModelSummary from lightning.pytorch.callbacks.early_stopping import EarlyStopping @@ -172,7 +173,9 @@ def main(): if not args.no_wandb and args.store_predictions: logging.info("Computing and storing predictions of best model.") - test = build_time_series_dataset(run.config, out_of_sample) + test_data = pd.concat( + [valid_data[-config['past_window']:], out_of_sample]) + test = build_time_series_dataset(run.config, test_data) test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict( test.to_dataloader(train=False, batch_size=batch_size), mode="raw",