From 4252c32fa12f2b03366b3250d3604a5bd3ca815c Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 11 Nov 2020 20:29:13 +0000
Subject: [PATCH] Revert base class change, enforce sync tensors across accelerators, added GPU test

---
 pytorch_lightning/accelerators/accelerator.py |  2 +-
 .../accelerators/cpu_accelerator.py           | 10 +++++-
 .../accelerators/gpu_accelerator.py           |  9 ++++-
 .../accelerators/tpu_accelerator.py           | 10 ++++--
 .../test_train_loop_logging_1_0.py            | 34 +++++++++++++++++++
 5 files changed, 60 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index b5bb4df068d5f..3b762e08ed5e6 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -232,7 +232,7 @@ def sync_tensor(self,
         Return:
             reduced value
         """
-        return tensor
+        raise NotImplementedError()
 
     def __getstate__(self):
         return {
diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py
index 083b5193ff8f3..66f9e4f0201b2 100644
--- a/pytorch_lightning/accelerators/cpu_accelerator.py
+++ b/pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, Union, Any
+
 import torch
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -80,3 +82,9 @@ def test_step(self, args):
         else:
             output = self.trainer.model.test_step(*args)
         return output
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor
diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py
index e66f5bcb8b48c..1a52c4037c8d3 100644
--- a/pytorch_lightning/accelerators/gpu_accelerator.py
+++ b/pytorch_lightning/accelerators/gpu_accelerator.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union, Optional, Any
 
 import torch
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.distributed.dist import LightningDistributed
 
@@ -120,3 +121,9 @@ def to_device(self, batch):
         # be referenced from and if there are multiple optimizers the batch will
         # wind up copying it to the same device repeatedly.
         return self.batch_to_device(batch, gpu_id)
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index 5f4e6cc22cacd..15386b133f8bd 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -14,13 +14,13 @@
 import io
 import os
 import re
-from typing import Optional
+from typing import Optional, Union, Any
 
 import torch
 import torch.multiprocessing as mp
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
@@ -337,3 +337,9 @@ def broadcast(self, obj, src=0):
         buffer = io.BytesIO(data.cpu().byte().numpy())
         obj = torch.load(buffer)
         return obj
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor
diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
index cd08d8cb659b8..cd8afd268cba8 100644
--- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -714,3 +714,37 @@ def validation_step(self, batch, batch_idx):
 
     assert trainer.logged_metrics['foo'] == fake_result
     assert trainer.logged_metrics['bar'] == fake_result
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_logging_sync_dist_true_gpu(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with GPU (should just return the original value)
+    """
+    fake_result = 1
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return {"x": loss}
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        gpus=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['bar'] == fake_result
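Reviewer note (not part of the patch): with Accelerator.sync_tensor now raising NotImplementedError, every accelerator has to state its reduction behaviour explicitly, and the single-device CPU/GPU/TPU accelerators satisfy the contract by returning the tensor unchanged. For contrast, a distributed accelerator is expected to perform a real reduction. The sketch below shows what such an override could look like using torch.distributed directly; the class name MyDDPAccelerator is hypothetical, only the 'sum' and 'mean'/'avg' ops are handled, and the actual Lightning DDP accelerators route through their own distributed sync helper rather than this exact code.

from typing import Any, Optional, Union

import torch
import torch.distributed as torch_distrib

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp


class MyDDPAccelerator(Accelerator):  # hypothetical example, for illustration only

    def sync_tensor(self,
                    tensor: Union[torch.Tensor],
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        # Default to the world process group when none is given.
        if group is None:
            group = torch_distrib.group.WORLD

        # This sketch only supports the default ('sum') and 'mean'/'avg' reductions:
        # sum-reduce across the group, then optionally divide by the world size.
        op = ReduceOp.SUM
        divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ('mean', 'avg')

        torch_distrib.all_reduce(tensor, op=op, group=group, async_op=False)

        if divide_by_world_size:
            tensor = tensor / torch_distrib.get_world_size(group)
        return tensor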