From 206a6604e96e57c3231937c04ad1e80a992a32fe Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Sun, 15 Nov 2020 15:38:38 +0000
Subject: [PATCH 1/6] Allow ddp plugin to move the input to a different device if needed

---
 .../accelerators/ddp2_accelerator.py          | 18 ++++++++++--------
 .../accelerators/ddp_accelerator.py           | 18 ++++++++++--------
 .../accelerators/ddp_cpu_spawn_accelerator.py | 18 ++++++++++--------
 .../accelerators/ddp_hpc_accelerator.py       | 18 ++++++++++--------
 .../accelerators/ddp_spawn_accelerator.py     | 18 ++++++++++--------
 pytorch_lightning/plugins/ddp_plugin.py       | 11 +++++++++++
 6 files changed, 61 insertions(+), 40 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index 2da9747a9be92..a6982a16f5c8a 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -61,6 +61,16 @@ def train(self):
         return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -68,14 +78,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index f99cd1149e5ae..67451ac6ada6e 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -151,6 +151,16 @@ def train(self):
         return results
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -158,14 +168,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index 221ed5769c35e..5071a67f66386 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -156,6 +156,16 @@ def ddp_train(self, process_idx, mp_queue, model):
         torch.cuda.empty_cache()
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -163,14 +173,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index b6d813f978943..e00a60b40d8d9 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -77,6 +77,16 @@ def get_device_ids(self):
         return device_ids
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -84,14 +94,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index a30d266ec1b2f..1e054b3cb1757 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -182,6 +182,16 @@ def get_device_ids(self):
         return device_ids
 
     def training_step(self, args):
+        return self._step(args)
+
+    def validation_step(self, args):
+        return self._step(args)
+
+    def test_step(self, args):
+        return self._step(args)
+
+    def _step(self, args):
+        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -189,14 +199,6 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def barrier(self, name: Optional[str] = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 4c4fdc8f0d368..bd957efb336c0 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -62,3 +62,14 @@ def configure_ddp(self, model, device_ids):
             **self._ddp_kwargs,
         )
         return model
+
+    def input_to_device(self, args: Any, model: LightningModule):
+        """
+        Override to handle custom input to device logic. For DDP, no move is required as this is handled internally
+        within the wrapper.
+        Args:
+            args: Inputs to the model.
+            model: Model to train.
+        Returns: args moved to correct device if needed.
+        """
+        return args

From ca6d5365ba01868bf3f5c319355609865c708b97 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Sun, 15 Nov 2020 17:11:36 +0000
Subject: [PATCH 2/6] Swapped name to on_before_forward to align with hooks in the future

---
 pytorch_lightning/accelerators/ddp2_accelerator.py          | 2 +-
 pytorch_lightning/accelerators/ddp_accelerator.py           | 2 +-
 pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 2 +-
 pytorch_lightning/accelerators/ddp_hpc_accelerator.py       | 2 +-
 pytorch_lightning/accelerators/ddp_spawn_accelerator.py     | 2 +-
 pytorch_lightning/plugins/ddp_plugin.py                     | 6 +++---
 6 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index a6982a16f5c8a..edaed03768a18 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -70,7 +70,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index 67451ac6ada6e..58e3dfc82d4df 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -160,7 +160,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index 5071a67f66386..bf0c1b116d557 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -165,7 +165,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index e00a60b40d8d9..09adbb9c3ec5e 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -86,7 +86,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index 1e054b3cb1757..ad7588b4440b3 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -191,7 +191,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.input_to_device(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index bd957efb336c0..c300adfa013fa 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -63,10 +63,10 @@ def configure_ddp(self, model, device_ids):
         )
         return model
 
-    def input_to_device(self, args: Any, model: LightningModule):
+    def on_before_forward(self, args: Any, model: LightningModule):
         """
-        Override to handle custom input to device logic. For DDP, no move is required as this is handled internally
-        within the wrapper.
+        Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
+        within the DDP wrapper.
         Args:
             args: Inputs to the model.
             model: Model to train.

From c4c3e7a9e73a2783029fff4b02a114ddc6a3540c Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Wed, 18 Nov 2020 13:49:51 +0000
Subject: [PATCH 3/6] Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/plugins/ddp_plugin.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index c300adfa013fa..879b710405015 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -67,9 +67,11 @@ def on_before_forward(self, args: Any, model: LightningModule):
         """
         Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
         within the DDP wrapper.
+
         Args:
             args: Inputs to the model.
             model: Model to train.
+
         Returns: args moved to correct device if needed.
         """
         return args

From b320682759dfd528376764a283fea957b11a0201 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 18 Nov 2020 14:07:39 +0000
Subject: [PATCH 4/6] Pass variable arg type to hook, add example

---
 pytorch_lightning/accelerators/ddp2_accelerator.py      | 2 +-
 pytorch_lightning/accelerators/ddp_accelerator.py       | 2 +-
 .../accelerators/ddp_cpu_spawn_accelerator.py           | 2 +-
 pytorch_lightning/accelerators/ddp_hpc_accelerator.py   | 2 +-
 pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 2 +-
 pytorch_lightning/plugins/ddp_plugin.py                 | 8 +++++++-
 6 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index edaed03768a18..9fcbdd4668ee9 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -70,7 +70,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index 58e3dfc82d4df..69d41cd024646 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -160,7 +160,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index bf0c1b116d557..2a090a72e2b5a 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -165,7 +165,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index 09adbb9c3ec5e..2ff9c2b7ddaae 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -86,7 +86,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index ad7588b4440b3..eac51393a5f2e 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -191,7 +191,7 @@ def test_step(self, args):
         return self._step(args)
 
     def _step(self, args):
-        args = self.ddp_plugin.on_before_forward(args, self.trainer.get_model())
+        args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 879b710405015..fea309a9ceb38 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -63,10 +63,16 @@ def configure_ddp(self, model, device_ids):
         )
         return model
 
-    def on_before_forward(self, args: Any, model: LightningModule):
+    def on_before_forward(self, model: LightningModule, *args):
         """
         Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
         within the DDP wrapper.
+
+        Example::
+
+            def on_before_forward(self, model, *args):
+                batch, batch_idx = args
+                return batch.to(model.device)
 
         Args:
             args: Inputs to the model.

From 03b7bbd75b8be4323759117cb4795a42702d322d Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 18 Nov 2020 14:09:36 +0000
Subject: [PATCH 5/6] Remove blank space (pep check)

---
 pytorch_lightning/plugins/ddp_plugin.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index fea309a9ceb38..33b0f672ec6bd 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -73,11 +73,9 @@ def on_before_forward(self, model: LightningModule, *args):
             def on_before_forward(self, model, *args):
                 batch, batch_idx = args
                 return batch.to(model.device)
-
         Args:
             args: Inputs to the model.
             model: Model to train.
-
         Returns: args moved to correct device if needed.
         """
         return args

From 7ade0c5c1122eb3244123b90286a9a1fde065f00 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 18 Nov 2020 14:14:31 +0000
Subject: [PATCH 6/6] Added blank line

---
 pytorch_lightning/plugins/ddp_plugin.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 33b0f672ec6bd..4d73d4bdded7d 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -73,6 +73,7 @@ def on_before_forward(self, model: LightningModule, *args):
             def on_before_forward(self, model, *args):
                 batch, batch_idx = args
                 return batch.to(model.device)
+
         Args:
            args: Inputs to the model.
            model: Model to train.
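Illustrative usage (a sketch, not part of the patches above): downstream code can customize how inputs reach the model by overriding the ``on_before_forward(model, *args)`` hook that patch 4 introduces. The base class is assumed to be ``DDPPlugin`` from ``pytorch_lightning/plugins/ddp_plugin.py``; the ``ManualDevicePlugin`` name and the ``(batch, batch_idx)`` argument layout are hypothetical, and the exact ``Trainer`` wiring may differ between Lightning versions::

    from pytorch_lightning import LightningModule
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


    class ManualDevicePlugin(DDPPlugin):
        """Hypothetical plugin that moves the batch itself instead of relying on DDP's scatter."""

        def on_before_forward(self, model: LightningModule, *args):
            # The accelerators call this as on_before_forward(model, *args) and then
            # forward the returned value into self.trainer.model(*args), so return a tuple.
            batch, *rest = args
            if hasattr(batch, "to"):
                batch = batch.to(model.device)
            return (batch, *rest)

Such a plugin would typically be handed to the ``Trainer`` through its ``plugins`` argument alongside a DDP-based accelerator.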