store predictions for all data parts
This commit is contained in:
parent
290576dbb9
commit
60741b68fc
@ -82,6 +82,28 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def store_predictions(run, part, model, data, batch_size=64):
|
||||
predictions = model.predict(
|
||||
data.to_dataloader(train=False, batch_size=batch_size),
|
||||
mode="raw",
|
||||
return_index=True,
|
||||
trainer_kwargs={
|
||||
'logger': False
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
for key, value in {
|
||||
'index': predictions.index,
|
||||
'predictions': predictions.output.prediction
|
||||
}.items():
|
||||
torch.save(value, os.path.join(tempdir, key + ".pt"))
|
||||
|
||||
pred_artifact = wandb.Artifact(
|
||||
f"prediction-{part}-{run.id}", type='prediction')
|
||||
pred_artifact.add_dir(tempdir)
|
||||
run.log_artifact(pred_artifact)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=args.log_level)
|
||||
@ -172,28 +194,15 @@ def main():
|
||||
ckpt_path=ckpt_path)
|
||||
|
||||
if not args.no_wandb and args.store_predictions:
|
||||
logging.info("Computing and storing predictions of best model.")
|
||||
test_data = pd.concat(
|
||||
[valid_data[-config['past_window']:], out_of_sample])
|
||||
test = build_time_series_dataset(run.config, test_data)
|
||||
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
|
||||
test.to_dataloader(train=False, batch_size=batch_size),
|
||||
mode="raw",
|
||||
return_index=True,
|
||||
trainer_kwargs={
|
||||
'logger': False
|
||||
})
|
||||
model = model.__class__.load_from_checkpoint(ckpt_path)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
for key, value in {
|
||||
'index': test_preds.index,
|
||||
'predictions': test_preds.output.prediction
|
||||
}.items():
|
||||
torch.save(value, os.path.join(tempdir, key + ".pt"))
|
||||
pred_artifact = wandb.Artifact(
|
||||
f"prediction-{run.id}", type='prediction')
|
||||
pred_artifact.add_dir(tempdir)
|
||||
run.log_artifact(pred_artifact)
|
||||
logging.info("Computing and storing predictions for best model.")
|
||||
store_predictions(run, 'train', model, train, batch_size=batch_size)
|
||||
store_predictions(run, 'valid', model, valid, batch_size=batch_size)
|
||||
store_predictions(run, 'test', model, test, batch_size=batch_size)
|
||||
|
||||
# Clean up models that do not have best/latest tags, to save space on wandb
|
||||
for artifact in wandb.Api().run(run.path).logged_artifacts():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user