store predictions for all data parts
This commit is contained in:
parent
290576dbb9
commit
60741b68fc
@ -82,6 +82,28 @@ def get_args():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def store_predictions(run, part, model, data, batch_size=64):
|
||||||
|
predictions = model.predict(
|
||||||
|
data.to_dataloader(train=False, batch_size=batch_size),
|
||||||
|
mode="raw",
|
||||||
|
return_index=True,
|
||||||
|
trainer_kwargs={
|
||||||
|
'logger': False
|
||||||
|
})
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
for key, value in {
|
||||||
|
'index': predictions.index,
|
||||||
|
'predictions': predictions.output.prediction
|
||||||
|
}.items():
|
||||||
|
torch.save(value, os.path.join(tempdir, key + ".pt"))
|
||||||
|
|
||||||
|
pred_artifact = wandb.Artifact(
|
||||||
|
f"prediction-{part}-{run.id}", type='prediction')
|
||||||
|
pred_artifact.add_dir(tempdir)
|
||||||
|
run.log_artifact(pred_artifact)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
logging.basicConfig(level=args.log_level)
|
logging.basicConfig(level=args.log_level)
|
||||||
@ -172,28 +194,15 @@ def main():
|
|||||||
ckpt_path=ckpt_path)
|
ckpt_path=ckpt_path)
|
||||||
|
|
||||||
if not args.no_wandb and args.store_predictions:
|
if not args.no_wandb and args.store_predictions:
|
||||||
logging.info("Computing and storing predictions of best model.")
|
|
||||||
test_data = pd.concat(
|
test_data = pd.concat(
|
||||||
[valid_data[-config['past_window']:], out_of_sample])
|
[valid_data[-config['past_window']:], out_of_sample])
|
||||||
test = build_time_series_dataset(run.config, test_data)
|
test = build_time_series_dataset(run.config, test_data)
|
||||||
test_preds = model.__class__.load_from_checkpoint(ckpt_path).predict(
|
model = model.__class__.load_from_checkpoint(ckpt_path)
|
||||||
test.to_dataloader(train=False, batch_size=batch_size),
|
|
||||||
mode="raw",
|
|
||||||
return_index=True,
|
|
||||||
trainer_kwargs={
|
|
||||||
'logger': False
|
|
||||||
})
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
logging.info("Computing and storing predictions for best model.")
|
||||||
for key, value in {
|
store_predictions(run, 'train', model, train, batch_size=batch_size)
|
||||||
'index': test_preds.index,
|
store_predictions(run, 'valid', model, valid, batch_size=batch_size)
|
||||||
'predictions': test_preds.output.prediction
|
store_predictions(run, 'test', model, test, batch_size=batch_size)
|
||||||
}.items():
|
|
||||||
torch.save(value, os.path.join(tempdir, key + ".pt"))
|
|
||||||
pred_artifact = wandb.Artifact(
|
|
||||||
f"prediction-{run.id}", type='prediction')
|
|
||||||
pred_artifact.add_dir(tempdir)
|
|
||||||
run.log_artifact(pred_artifact)
|
|
||||||
|
|
||||||
# Clean up models that do not have best/latest tags, to save space on wandb
|
# Clean up models that do not have best/latest tags, to save space on wandb
|
||||||
for artifact in wandb.Api().run(run.path).logged_artifacts():
|
for artifact in wandb.Api().run(run.path).logged_artifacts():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user