diff --git a/train.py b/train.py index ff03ac9..391d379 100644 --- a/train.py +++ b/train.py @@ -134,11 +134,11 @@ def main(args): state["best_loss"] = val_loss state["best_checkpoint"] = checkpoint_path + state["epochs"] += 1 # CHECKPOINT print("Saving model...") model_utils.save_model(model, optimizer, state, checkpoint_path) - state["epochs"] += 1 #### TESTING #### # Test loss