Sharded Plugin 3/n: Expose step input to DDP plugin (#4686)
* Allow ddp plugin to move the input to a different device if needed

* Swapped name to on_before_forward to align with hooks in the future

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Jirka Borovec <[email protected]>

* Pass variable arg type to hook, add example

* Remove blank space (pep check)

* Added blank line

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored and rohitgr7 committed Nov 21, 2020
1 parent f38af28 commit d28ee2f
Showing 6 changed files with 68 additions and 40 deletions.
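
All five accelerator diffs below make the same change: training_step, validation_step and test_step now funnel into a shared _step helper, which first hands the raw step inputs to the DDP plugin's on_before_forward hook and then calls the wrapped model with whatever the hook returns. As a rough, standalone sketch of that contract (toy names, not Lightning code), the dispatch looks like this:

    # Toy illustration of the dispatch pattern introduced in the diffs below (not Lightning code).
    class NoOpPlugin:
        def on_before_forward(self, model, *args):
            # The return value is unpacked straight into the model call,
            # so a plugin that moves tensors must still return a tuple of args.
            return args

    def _step(plugin, model, args):
        args = plugin.on_before_forward(model, *args)
        return model(*args)

    print(_step(NoOpPlugin(), lambda *a: sum(a), (1, 2, 3)))  # prints 6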
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -61,21 +61,23 @@ def train(self):
         return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
         else:
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -151,21 +151,23 @@ def train(self):
         return results
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
         else:
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -156,21 +156,23 @@ def ddp_train(self, process_idx, mp_queue, model):
         torch.cuda.empty_cache()
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
         else:
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -77,21 +77,23 @@ def get_device_ids(self):
         return device_ids
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
         else:
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
18 changes: 10 additions & 8 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -182,21 +182,23 @@ def get_device_ids(self):
         return device_ids
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
         else:
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
18 changes: 18 additions & 0 deletions pytorch_lightning/plugins/ddp_plugin.py
@@ -62,3 +62,21 @@ def configure_ddp(self, model, device_ids):
             **self._ddp_kwargs,
         )
         return model
+
+    def on_before_forward(self, model: LightningModule, *args):
+        """
+        Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
+        within the DDP wrapper.
+
+        Example::
+
+            def on_before_forward(self, model, *args):
+                batch, batch_idx = args
+                return batch.to(model.device)
+
+        Args:
+            args: Inputs to the model.
+            model: Model to train.
+        Returns: args moved to correct device if needed.
+        """
+        return args
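
For downstream users, the new hook is meant to be overridden in a custom plugin and handed to the Trainer. The sketch below is an assumption about how that wiring looked in Lightning of this era: the MoveToModelDeviceDDPPlugin class, the gpus/accelerator flags, and the exact import are illustrative rather than part of this commit; only on_before_forward itself comes from the diff above.

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin  # module touched by this commit


    class MoveToModelDeviceDDPPlugin(DDPPlugin):
        """Hypothetical plugin: push the batch onto the LightningModule's device before forward."""

        def on_before_forward(self, model, *args):
            batch, batch_idx = args
            # The accelerator calls `self.trainer.model(*returned_args)`, so return a tuple.
            return (batch.to(model.device), batch_idx)


    # Assumed Trainer wiring for this period of the API; flags are version-dependent.
    # trainer = Trainer(gpus=2, accelerator="ddp", plugins=[MoveToModelDeviceDDPPlugin()])
    # trainer.fit(my_lightning_module)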
