diff --git a/pretraining.py b/pretraining.py index 0038450..0705203 100644 --- a/pretraining.py +++ b/pretraining.py @@ -659,7 +659,7 @@ def group_texts(examples): save_model(training_args.output_dir, model, tokenizer, training_args) # Evaluation - if training_args.do_eval and trainer.is_world_process_zero(): + if training_args.do_eval: logger.info("*** Evaluate ***") metrics = trainer.evaluate()