Commit c63c00c: try fix
rohitgr7 committed Sep 24, 2021
1 parent: 0f14f3b
Showing 1 changed file with 15 additions and 9 deletions: pytorch_lightning/plugins/training_type/ipu.py

@@ -96,6 +96,7 @@ def __init__(
         self.autoreport = autoreport
         self.autoreport_dir = autoreport_dir
         self.poptorch_models = {}
+        self._original_accumulate_grad_batches: Optional[int] = None
         self._training_opts = training_opts
         self._inference_opts = inference_opts

@@ -149,7 +150,7 @@ def _create_opts(self, training: bool) -> "poptorch.Options":
         opts = poptorch.Options()
         opts.deviceIterations(self.device_iterations)
         opts.replicationFactor(self.replication_factor)
-        gradient_accumulation = self.lightning_module.trainer.accumulate_grad_batches if training else 1
+        gradient_accumulation = self.accumulate_grad_batches if training else 1
         opts.Training.gradientAccumulation(gradient_accumulation)

         if os.environ.get("PL_GLOBAL_SEED"):
@@ -184,22 +185,27 @@ def _convert_to_poptorch_loader(
         dataloader = poptorch.DataLoader(**dl_kwargs, options=opts)
         return dataloader

+    @property
+    def accumulate_grad_batches(self) -> int:
+        return self._original_accumulate_grad_batches
+
     def _handle_gradient_accumulation_steps(self) -> None:
         """Override the trainer.accumulation_scheduler to act as ``accumulate_grad_batches=1`` if gradient
         accumulation has been set.
         ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
         """
-        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
-
-        if accumulation_scheduler.epochs != [0]:
+        accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+        if not isinstance(accumulate_grad_batches, int):
             raise MisconfigurationException(
-                "IPUs currently does not support different `accumulate_grad_batches` at different epoch."
+                "IPUs currently only support `Trainer.accumulate_grad_batches` being an integer."
+                f" Received {accumulate_grad_batches}"
             )

-        # TODO(@tchaton): Add support for accumulate_grad_batches being a dictionary
-        self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})
-
-        # accumulation_scheduler.scheduling.update({0: 1})
+        # save the original value which will be used to update the global step progress
+        self._original_accumulate_grad_batches = accumulate_grad_batches
+        if accumulate_grad_batches > 1:
+            # TODO(@tchaton): Add support for accumulate_grad_batches being a dictionary
+            self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})

     @property
     def _n_replicate(self):
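
For context, a hedged usage sketch of the constraint the new isinstance check enforces. The `ipus=8` Trainer flag is an assumption based on the PL 1.4-era API, and the exception fires when the IPU plugin runs `_handle_gradient_accumulation_steps` during fit setup, not at `Trainer` construction; this is an illustration, not code from the commit.

import pytorch_lightning as pl

# Accepted after this commit: a plain int. The plugin saves the value in
# `_original_accumulate_grad_batches`, resets Lightning's scheduler to
# `GradientAccumulationScheduler({0: 1})` when the value is > 1, and hands
# the saved value to poptorch via `opts.Training.gradientAccumulation(...)`,
# so the IPU accumulates gradients internally while `optimizer_step` runs
# on every batch.
trainer = pl.Trainer(ipus=8, accumulate_grad_batches=4)

# Rejected after this commit: non-int schedules such as a per-epoch dict.
# Once training starts, `_handle_gradient_accumulation_steps` raises
# MisconfigurationException: "IPUs currently only support
# `Trainer.accumulate_grad_batches` being an integer. Received {0: 2, 4: 8}"
trainer = pl.Trainer(ipus=8, accumulate_grad_batches={0: 2, 4: 8})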
