From 06ec15c71cb0c368087ac83e466c515817842c71 Mon Sep 17 00:00:00 2001
From: Daniel Campos
Date: Wed, 9 Mar 2022 14:57:56 -0600
Subject: [PATCH 1/3] updating batch size for optimizer

---
 src/sparseml/transformers/sparsification/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 9a48b1cd589..a385dd56538 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -202,9 +202,12 @@ def create_optimizer(self):
         if not self.manager:
             return
 
+        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else self.args._n_gpu
+        if num_devices < 1:
+            num_devices = 1
         total_batch_size = (
             self.args.per_device_train_batch_size
-            * (self.args._n_gpu or 1)
+            * num_devices
             * self.args.gradient_accumulation_steps
         )
         self.manager_steps_per_epoch = math.ceil(

From 1b46322aaff65106acdcb07b6466266fdbdde786 Mon Sep 17 00:00:00 2001
From: Daniel Campos
Date: Wed, 9 Mar 2022 15:03:58 -0600
Subject: [PATCH 2/3] make style updateS

---
 src/sparseml/transformers/sparsification/trainer.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index a385dd56538..f333f305410 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -202,9 +202,13 @@ def create_optimizer(self):
         if not self.manager:
             return
 
-        num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else self.args._n_gpu
+        num_devices = (
+            torch.distributed.get_world_size()
+            if torch.distributed.is_initialized()
+            else self.args._n_gpu
+        )
         if num_devices < 1:
-            num_devices = 1
+            num_devices = 1
         total_batch_size = (
             self.args.per_device_train_batch_size
             * num_devices
@@ -271,8 +275,7 @@ def create_scheduler(self, num_training_steps: int):
 
         # allow SparseML to manage LR and set a dummy scheduler
         self.lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
-            self.optimizer,
-            lambda _: 1.0,
+            self.optimizer, lambda _: 1.0,
         )
         _LOGGER.warning("Overrode the lr_scheduler from SparseML recipe")

From 28ad3d452998ffbc3117ff589a13c8c90dd4ff2e Mon Sep 17 00:00:00 2001
From: Daniel Campos
Date: Wed, 9 Mar 2022 15:14:11 -0600
Subject: [PATCH 3/3] make style updateS

---
 src/sparseml/transformers/sparsification/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index f333f305410..e959474432a 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -275,7 +275,8 @@ def create_scheduler(self, num_training_steps: int):
 
         # allow SparseML to manage LR and set a dummy scheduler
         self.lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
-            self.optimizer, lambda _: 1.0,
+            self.optimizer,
+            lambda _: 1.0,
         )
         _LOGGER.warning("Overrode the lr_scheduler from SparseML recipe")
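
Note on the change: the patch series replaces the `_n_gpu`-only device count with the distributed world size whenever `torch.distributed` is initialized, so `total_batch_size` (and therefore the SparseML manager's steps-per-epoch estimate) stays correct under multi-node DDP, where each process only sees its local GPU count. Below is a minimal standalone sketch of that logic, assuming a generic `args` object with the same fields as HuggingFace `TrainingArguments`; the function name and the `dataset_length` parameter are illustrative stand-ins, not part of the actual Trainer code.

import math

import torch


def effective_steps_per_epoch(args, dataset_length: int) -> int:
    """Sketch of the device-count logic introduced by this patch series.

    Under DDP every process consumes per_device_train_batch_size samples
    per step, so the effective batch size scales with the world size
    rather than with the locally visible GPU count.
    """
    # Prefer the distributed world size; fall back to the local GPU count.
    num_devices = (
        torch.distributed.get_world_size()
        if torch.distributed.is_initialized()
        else args._n_gpu
    )
    if num_devices < 1:
        num_devices = 1  # CPU-only or single-process fallback

    total_batch_size = (
        args.per_device_train_batch_size
        * num_devices
        * args.gradient_accumulation_steps
    )
    # Steps per epoch as the sparsification manager would count them.
    return math.ceil(dataset_length / total_batch_size)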