add final evaluation after training

This commit is contained in:
Filip Stefaniuk 2024-09-04 17:02:07 +02:00
parent ffc93b0ae5
commit dcd35508fd

View File

@ -59,6 +59,14 @@ def get_args():
help="Run validation every n batches." help="Run validation every n batches."
) )
parser.add_argument(
'-p',
'--patience',
default=8,
type=int,
help="Patience for early stopping."
)
parser.add_argument('--no-wandb', action='store_true', parser.add_argument('--no-wandb', action='store_true',
help='Disables wandb, for testing.') help='Disables wandb, for testing.')
@ -191,7 +199,7 @@ def main():
early_stopping = EarlyStopping( early_stopping = EarlyStopping(
monitor="val_loss", monitor="val_loss",
mode="min", mode="min",
patience=5) patience=args.patience)
batch_size = config['batch_size'] batch_size = config['batch_size']
logging.info(f"Training batch size {batch_size}.") logging.info(f"Training batch size {batch_size}.")
@ -223,6 +231,13 @@ def main():
batch_size=batch_size, train=False, num_workers=3 batch_size=batch_size, train=False, num_workers=3
)) ))
# Run validation with best model to log min val_loss
# TODO: Maybe use different metric like min_val_loss ?
ckpt_path = trainer.checkpoint_callback.best_model_path or None
trainer.validate(model, dataloaders=valid.to_dataloader(
batch_size=batch_size, train=False, num_workers=3),
ckpt_path=ckpt_path)
if __name__ == '__main__': if __name__ == '__main__':
main() main()