
Revert base class change, enforce sync tensors across accelerators, added GPU test
SeanNaren committed Nov 11, 2020
1 parent 6cd7281 commit 4252c32
Showing 5 changed files with 60 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -232,7 +232,7 @@ def sync_tensor(self,
         Return:
             reduced value
         """
-        return tensor
+        raise NotImplementedError()
 
     def __getstate__(self):
         return {
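The base class now raises NotImplementedError again, so every accelerator has to state explicitly how (or whether) it synchronizes a tensor across processes. For contrast with the identity implementations added below, a distributed accelerator would perform a real reduction. A minimal sketch of such an override, not code from this commit, built on torch.distributed and assuming the default process group is already initialized:

from typing import Any, Optional, Union

import torch
import torch.distributed as torch_distrib


def sync_ddp_tensor(tensor: torch.Tensor,
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[torch_distrib.ReduceOp, str]] = None) -> torch.Tensor:
    """Hypothetical helper: reduce a tensor across all processes."""
    if not (torch_distrib.is_available() and torch_distrib.is_initialized()):
        # Single-process run: behave like the CPU/GPU/TPU accelerators in this commit.
        return tensor

    # Default to a sum, optionally dividing by world size for a mean.
    op = torch_distrib.ReduceOp.SUM
    divide_by_world_size = False
    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
        divide_by_world_size = True
    elif isinstance(reduce_op, torch_distrib.ReduceOp):
        op = reduce_op

    if group is None:
        group = torch_distrib.group.WORLD

    torch_distrib.all_reduce(tensor, op=op, group=group)
    if divide_by_world_size:
        tensor = tensor / torch_distrib.get_world_size(group)
    return tensor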
10 changes: 9 additions & 1 deletion pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, Union, Any
+
 import torch
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -80,3 +82,9 @@ def test_step(self, args):
         else:
             output = self.trainer.model.test_step(*args)
         return output
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor
9 changes: 8 additions & 1 deletion pytorch_lightning/accelerators/gpu_accelerator.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union, Optional, Any
 
 import torch
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.distributed.dist import LightningDistributed

@@ -120,3 +121,9 @@ def to_device(self, batch):
         # be referenced from and if there are multiple optimizers the batch will
         # wind up copying it to the same device repeatedly.
         return self.batch_to_device(batch, gpu_id)
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor
10 changes: 8 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -14,13 +14,13 @@
 import io
 import os
 import re
-from typing import Optional
+from typing import Optional, Union, Any
 
 import torch
 import torch.multiprocessing as mp
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save
@@ -337,3 +337,9 @@ def broadcast(self, obj, src=0):
         buffer = io.BytesIO(data.cpu().byte().numpy())
         obj = torch.load(buffer)
         return obj
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return tensor
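All three single-device accelerators now satisfy the base-class contract with an identity implementation, so sync_dist=True no longer hits NotImplementedError on CPU, a single GPU, or a TPU core; there is simply nothing to reduce. A rough illustration of that contract, for demonstration only (accelerators are normally built by the Trainer, and this assumes the constructor tolerates a bare trainer reference):

import torch

from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator

# Demonstration only: the Trainer normally constructs and wires up the accelerator.
accelerator = CPUAccelerator(trainer=None)

value = torch.tensor(1.0)
synced = accelerator.sync_tensor(value, reduce_op='sum')

# With a single process the tensor comes back unchanged.
assert synced is value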
34 changes: 34 additions & 0 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -714,3 +714,37 @@ def validation_step(self, batch, batch_idx):
 
     assert trainer.logged_metrics['foo'] == fake_result
     assert trainer.logged_metrics['bar'] == fake_result
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_logging_sync_dist_true_gpu(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with GPU (should just return the original value)
+    """
+    fake_result = 1
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            return {"x": loss}
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        gpus=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['bar'] == fake_result
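The new test only runs on a machine with CUDA available; elsewhere the skipif marker skips it. One way to select just this test programmatically, equivalent to passing the node id on the pytest command line:

import pytest

# Run only the new GPU test; it is skipped automatically when no GPU is visible.
pytest.main([
    "tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_gpu",
    "-v",
])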
