fix training script predictions
This commit is contained in:
parent
03d760bdf1
commit
134860d09e
@ -6,6 +6,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import torch
|
import torch
|
||||||
import lightning.pytorch as pl
|
import lightning.pytorch as pl
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
from lightning.pytorch.utilities.model_summary import ModelSummary
|
from lightning.pytorch.utilities.model_summary import ModelSummary
|
||||||
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
||||||
@ -172,7 +173,9 @@ def main():
|
|||||||
|
|
||||||
if not args.no_wandb and args.store_predictions:
|
if not args.no_wandb and args.store_predictions:
|
||||||
logging.info("Computing and storing predictions of best model.")
|
logging.info("Computing and storing predictions of best model.")
|
||||||
test = build_time_series_dataset(run.config, out_of_sample)
|
test_data = pd.concat(
|
||||||
|
[valid_data[-config['past_window']:], out_of_sample])
|
||||||
|
test = build_time_series_dataset(run.config, test_data)
|
||||||
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
|
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
|
||||||
test.to_dataloader(train=False, batch_size=batch_size),
|
test.to_dataloader(train=False, batch_size=batch_size),
|
||||||
mode="raw",
|
mode="raw",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user