diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index 2da9747a9be92..9fcbdd4668ee9 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -61,6 +61,16 @@ def train(self):
         return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -68,14 +78,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index f99cd1149e5ae..69d41cd024646 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -151,6 +151,16 @@ def train(self):
         return results
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -158,14 +168,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index 221ed5769c35e..2a090a72e2b5a 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -156,6 +156,16 @@ def ddp_train(self, process_idx, mp_queue, model):
         torch.cuda.empty_cache()
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -163,14 +173,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index b6d813f978943..2ff9c2b7ddaae 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -77,6 +77,16 @@ def get_device_ids(self):
         return device_ids
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -84,14 +94,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index a30d266ec1b2f..eac51393a5f2e 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -182,6 +182,16 @@ def get_device_ids(self):
         return device_ids
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -189,14 +199,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 4c4fdc8f0d368..4d73d4bdded7d 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -62,3 +62,21 @@ def configure_ddp(self, model, device_ids):
             **self._ddp_kwargs,
         )
         return model
+
+    def on_before_forward(self, model: LightningModule, *args):
+        """
+        Override to handle custom logic for moving inputs to the device. For DDP, no logic is required,
+        as this is handled internally within the DDP wrapper.
+
+        Example::
+
+            def on_before_forward(self, model, *args):
+                batch, batch_idx = args
+                return batch.to(model.device)
+
+        Args:
+            model: Model to train.
+            args: Inputs to the model.
+        Returns: args, moved to the correct device if needed.
+        """
+        return args
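
For reference, here is a minimal sketch (not part of this diff) of how a user-defined plugin could override the new on_before_forward hook introduced above. The MyDDPPlugin name, the (batch, batch_idx) argument layout, and the Trainer usage line are illustrative assumptions, not code from this PR.

# Illustrative sketch only: a custom DDPPlugin that moves the batch to the
# LightningModule's device in on_before_forward. Assumes the accelerator
# passes (batch, batch_idx) and that the batch is a single tensor; real
# batches may need a recursive transfer.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class MyDDPPlugin(DDPPlugin):

    def on_before_forward(self, model: LightningModule, *args):
        batch, batch_idx = args
        # Move only the batch; batch_idx is a plain int and needs no transfer.
        batch = batch.to(model.device)
        # Return a tuple, since the accelerators call self.trainer.model(*args)
        # with whatever this hook returns.
        return batch, batch_idx


# Usage sketch (assumes this Lightning version's Trainer accepts plugin
# instances via the `plugins` argument):
# trainer = Trainer(accelerator="ddp", gpus=2, plugins=[MyDDPPlugin()])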