fix training script predictions
This commit is contained in:
parent
03d760bdf1
commit
134860d09e
@ -6,6 +6,7 @@ import os
|
||||
import tempfile
|
||||
import torch
|
||||
import lightning.pytorch as pl
|
||||
import pandas as pd
|
||||
|
||||
from lightning.pytorch.utilities.model_summary import ModelSummary
|
||||
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
||||
@ -172,7 +173,9 @@ def main():
|
||||
|
||||
if not args.no_wandb and args.store_predictions:
|
||||
logging.info("Computing and storing predictions of best model.")
|
||||
test = build_time_series_dataset(run.config, out_of_sample)
|
||||
test_data = pd.concat(
|
||||
[valid_data[-config['past_window']:], out_of_sample])
|
||||
test = build_time_series_dataset(run.config, test_data)
|
||||
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
|
||||
test.to_dataloader(train=False, batch_size=batch_size),
|
||||
mode="raw",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user