diff --git a/scripts/train.py b/scripts/train.py
index e1a2462..5990d3c 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -59,6 +59,14 @@ def get_args():
         help="Run validation every n batches."
     )

+    parser.add_argument(
+        '-p',
+        '--patience',
+        default=8,
+        type=int,
+        help="Patience for early stopping."
+    )
+
     parser.add_argument('--no-wandb', action='store_true',
                         help='Disables wandb, for testing.')

@@ -191,7 +199,7 @@ def main():
     early_stopping = EarlyStopping(
         monitor="val_loss",
         mode="min",
-        patience=5)
+        patience=args.patience)

     batch_size = config['batch_size']
     logging.info(f"Training batch size {batch_size}.")
@@ -223,6 +231,13 @@ def main():
         batch_size=batch_size, train=False, num_workers=3
     ))

+    # Run validation with best model to log min val_loss
+    # TODO: Maybe use different metric like min_val_loss ?
+    ckpt_path = trainer.checkpoint_callback.best_model_path or None
+    trainer.validate(model, dataloaders=valid.to_dataloader(
+        batch_size=batch_size, train=False, num_workers=3),
+        ckpt_path=ckpt_path)
+

 if __name__ == '__main__':
     main()