store predictions for all data parts

This commit is contained in:
Filip Stefaniuk 2024-09-16 20:28:45 +02:00
parent 290576dbb9
commit 60741b68fc

View File

@ -82,6 +82,28 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def store_predictions(run, part, model, data, batch_size=64):
predictions = model.predict(
data.to_dataloader(train=False, batch_size=batch_size),
mode="raw",
return_index=True,
trainer_kwargs={
'logger': False
})
with tempfile.TemporaryDirectory() as tempdir:
for key, value in {
'index': predictions.index,
'predictions': predictions.output.prediction
}.items():
torch.save(value, os.path.join(tempdir, key + ".pt"))
pred_artifact = wandb.Artifact(
f"prediction-{part}-{run.id}", type='prediction')
pred_artifact.add_dir(tempdir)
run.log_artifact(pred_artifact)
def main(): def main():
args = get_args() args = get_args()
logging.basicConfig(level=args.log_level) logging.basicConfig(level=args.log_level)
@ -172,28 +194,15 @@ def main():
ckpt_path=ckpt_path) ckpt_path=ckpt_path)
if not args.no_wandb and args.store_predictions: if not args.no_wandb and args.store_predictions:
logging.info("Computing and storing predictions of best model.")
test_data = pd.concat( test_data = pd.concat(
[valid_data[-config['past_window']:], out_of_sample]) [valid_data[-config['past_window']:], out_of_sample])
test = build_time_series_dataset(run.config, test_data) test = build_time_series_dataset(run.config, test_data)
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict( model = model.__class__.load_from_checkpoint(ckpt_path)
test.to_dataloader(train=False, batch_size=batch_size),
mode="raw",
return_index=True,
trainer_kwargs={
'logger': False
})
with tempfile.TemporaryDirectory() as tempdir: logging.info("Computing and storing predictions for best model.")
for key, value in { store_predictions(run, 'train', model, train, batch_size=batch_size)
'index': test_preds.index, store_predictions(run, 'valid', model, valid, batch_size=batch_size)
'predictions': test_preds.output.prediction store_predictions(run, 'test', model, test, batch_size=batch_size)
}.items():
torch.save(value, os.path.join(tempdir, key + ".pt"))
pred_artifact = wandb.Artifact(
f"prediction-{run.id}", type='prediction')
pred_artifact.add_dir(tempdir)
run.log_artifact(pred_artifact)
# Clean up models that do not have best/latest tags, to save space on wandb # Clean up models that do not have best/latest tags, to save space on wandb
for artifact in wandb.Api().run(run.path).logged_artifacts(): for artifact in wandb.Api().run(run.path).logged_artifacts():