diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 31c4f1cb354..bb0bef4e858 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -233,14 +233,17 @@ def create_optimizer(self):
         if not self.manager:
             return
-        n_gpu = (
+
+        num_devices = (
             torch.distributed.get_world_size()
             if torch.distributed.is_initialized()
             else self.args._n_gpu
         )
 
+        if num_devices < 1:
+            num_devices = 1
         total_batch_size = (
             self.args.per_device_train_batch_size
-            * n_gpu
+            * num_devices
             * self.args.gradient_accumulation_steps
         )
         self.manager_steps_per_epoch = math.ceil(
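
For reference, a minimal standalone sketch of the batch-size math this diff changes. It is an illustration, not the PR's code: `compute_manager_steps_per_epoch`, `local_gpu_count`, and `dataset_size` are hypothetical names, and since the hunk truncates the argument of `math.ceil(`, the `dataset_size / total_batch_size` dividend below is an assumed stand-in.

```python
import math

import torch


def compute_manager_steps_per_epoch(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    local_gpu_count: int,
    dataset_size: int,
) -> int:
    """Standalone illustration of the device-count fallback in this diff."""
    # Use the distributed world size when a process group is initialized
    # (the multi-GPU / multi-node case); otherwise fall back to the local
    # GPU count, which plays the role of self.args._n_gpu in the PR.
    num_devices = (
        torch.distributed.get_world_size()
        if torch.distributed.is_initialized()
        else local_gpu_count
    )
    # The clamp added by this diff: a CPU-only run can report 0 devices,
    # which would make total_batch_size 0 and break the division below.
    if num_devices < 1:
        num_devices = 1

    total_batch_size = (
        per_device_train_batch_size * num_devices * gradient_accumulation_steps
    )
    # Assumed dividend: the hunk cuts off what is actually passed to ceil().
    return math.ceil(dataset_size / total_batch_size)


# Example: batch size 8 per device, 2 accumulation steps, CPU-only (0 GPUs),
# 1000 samples -> total_batch_size = 8 * 1 * 2 = 16, ceil(1000 / 16) = 63.
print(compute_manager_steps_per_epoch(8, 2, 0, 1000))
```

In the same CPU-only example, the pre-diff code would have computed `total_batch_size = 8 * 0 * 2 = 0` and raised a `ZeroDivisionError` in the subsequent division, which is presumably the motivation for clamping `num_devices` to at least 1.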