fix training script predictions

This commit is contained in:
Filip Stefaniuk 2024-09-14 20:29:53 +02:00
parent 03d760bdf1
commit 134860d09e

View File

@ -6,6 +6,7 @@ import os
import tempfile
import torch
import lightning.pytorch as pl
import pandas as pd
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
@ -172,7 +173,9 @@ def main():
if not args.no_wandb and args.store_predictions:
logging.info("Computing and storing predictions of best model.")
test = build_time_series_dataset(run.config, out_of_sample)
test_data = pd.concat(
[valid_data[-config['past_window']:], out_of_sample])
test = build_time_series_dataset(run.config, test_data)
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
test.to_dataloader(train=False, batch_size=batch_size),
mode="raw",