diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ca7be742c1fe..5fab74f8b2135 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -241,6 +241,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))
 
+- Fixed reduction using `self.log(sync_dist=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
+
+
 - Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
 
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 77079e6397f6f..38d09137b33ad 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -45,22 +45,33 @@ class MetricSource(LightningEnum):
 @dataclass
 class _Sync:
     fn: Optional[Callable] = None
-    should: bool = False
+    _should: bool = False
     rank_zero_only: bool = False
     op: Optional[str] = None
     group: Optional[Any] = None
 
     def __post_init__(self) -> None:
-        if self.fn is None:
-            self.fn = self.no_op
+        self._generate_sync_fn()
+
+    @property
+    def should(self) -> bool:
+        return self._should
+
+    @should.setter
+    def should(self, should: bool) -> None:
+        self._should = should
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
+    def _generate_sync_fn(self) -> None:
+        """Used to compute the syncing function and cache it."""
+        fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
+        # save the function as `_fn` as the meta is being re-created and the object references need to match.
+        self._fn = partial(fn, reduce_op=self.op, group=self.group)
 
     @property
     def __call__(self) -> Any:
-        return (
-            partial(self.fn, reduce_op=self.op, group=self.group)
-            if self.should and not self.rank_zero_only
-            else self.no_op
-        )
+        return self._fn
 
     @staticmethod
     def no_op(value: Any, *_, **__) -> Any:
@@ -75,31 +86,28 @@ class _Metadata:
     logger: bool = True
     on_step: bool = False
     on_epoch: bool = True
-    _reduce_fx: Callable = torch.mean
+    reduce_fx: Callable = torch.mean
     enable_graph: bool = False
     dataloader_idx: Optional[int] = None
     metric_attribute: Optional[str] = None
     _sync: Optional[_Sync] = None
 
-    @property
-    def reduce_fx(self) -> Callable:
-        return self._reduce_fx
+    def __post_init__(self) -> None:
+        self._parse_reduce_fx()
 
-    @reduce_fx.setter
-    def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
+    def _parse_reduce_fx(self) -> None:
         error = (
             "Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported."
             " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
- f" Found: {reduce_fx}" + f" Found: {self.reduce_fx}" ) - self._reduce_fx = reduce_fx - if isinstance(reduce_fx, str): - reduce_fx = reduce_fx.lower() + if isinstance(self.reduce_fx, str): + reduce_fx = self.reduce_fx.lower() if reduce_fx == "avg": reduce_fx = "mean" if reduce_fx not in ("min", "max", "mean", "sum"): raise MisconfigurationException(error) - self._reduce_fx = getattr(torch, reduce_fx) + self.reduce_fx = getattr(torch, reduce_fx) elif self.is_custom_reduction: raise MisconfigurationException(error) @@ -178,11 +186,11 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: value = value.float() - self._forward_cache = value # performance: no need to accumulate on values only logged on_step if self.meta.on_step and not self.meta.on_epoch: - self.value = self.meta.sync(value) + self._forward_cache = self.value = self.meta.sync(value) return + self._forward_cache = value # perform accumulation with reduction if self.meta.is_mean_reduction: self.value += value.mean() * batch_size @@ -201,8 +209,7 @@ def compute(self) -> torch.Tensor: if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) return value / cumulated_batch_size - elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction: - return value + return value return self.value.compute() def reset(self) -> None: @@ -449,12 +456,12 @@ def log( logger=logger, on_step=on_step, on_epoch=on_epoch, + reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) - meta.reduce_fx = reduce_fx - meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) + meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) # register logged value if it doesn't exist if key not in self: @@ -680,6 +687,8 @@ def load_state_dict( if not metrics: return + + # iterate through result metrics and re-attached Metric references on reload. 
 
         result_metrics = self.result_metrics
         for metric_attribute, metric in metrics.items():
             for result_metric in result_metrics:
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 4f254b6824489..73de501fc9911 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -181,10 +181,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD
 
-    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
-
-    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-        divide_by_world_size = True
+    if isinstance(reduce_op, str):
+        if reduce_op.lower() in ("avg", "mean"):
+            op = ReduceOp.SUM
+            divide_by_world_size = True
+        else:
+            op = getattr(ReduceOp, reduce_op.upper())
+    else:
+        op = reduce_op
 
     # sync all processes before reduction
     torch.distributed.barrier(group=group)
diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 7c6a985d09f33..dd03407080953 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -27,7 +27,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
+from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
 from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -336,7 +336,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
     # default sync fn
     new_results = ResultCollection(False, device)
     new_results.load_state_dict(state_dict, map_location="cpu")
-    assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op
+    assert new_results["validation_step.v"].meta.sync.fn is None
 
     # check map location
     assert new_results["validation_step.v"].value.device.type == "cpu"
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index bc3a35e95c21c..9d164b989f434 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize):
 
 def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-    sync = _Sync(sync_ddp_if_available, should=True, op="SUM")
+    sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
     actual = sync(tensor)
     assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 308cad8fcd632..385870fedd890 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -359,37 +359,74 @@ def get_expected(on_epoch, values):
     assert is_included if should_include else not is_included
 
 
-@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))])
+class LoggingSyncDistModel(BoringModel):
+    def __init__(self, fake_result):
+        super().__init__()
+        self.fake_result = fake_result
+
+    @property
+    def rank(self) -> int:
+        return self.trainer.global_rank
+
+    def training_step(self, batch, batch_idx):
+        value = self.fake_result + self.rank
+        self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
+        self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
+        self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
+        self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
+        self.log("foo_5", batch_idx + self.rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max")
+
+        self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
+        return super().training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("bar_3", batch_idx + self.rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
+        return super().validation_step(batch, batch_idx)
+
+
+@pytest.mark.parametrize(
+    "gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))]
+)
 def test_logging_sync_dist_true(tmpdir, gpus):
     """
-    Tests to ensure that the sync_dist flag works (should just return the original value)
+    Tests to ensure that the sync_dist flag works with the sum, mean and max reduce functions
     """
     fake_result = 1
-
-    class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            return super().training_step(batch, batch_idx)
-
-        def validation_step(self, batch, batch_idx):
-            self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            return super().validation_step(batch, batch_idx)
-
-    model = TestModel()
+    model = LoggingSyncDistModel(fake_result)
     trainer = Trainer(
+        max_epochs=1,
         default_root_dir=tmpdir,
-        limit_train_batches=1,
-        limit_val_batches=1,
-        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
         weights_summary=None,
         gpus=gpus,
     )
     trainer.fit(model)
 
-    assert trainer.logged_metrics["foo"] == fake_result
-    assert trainer.logged_metrics["foo_2"] == 2
-    assert trainer.logged_metrics["bar"] == fake_result
+    num_devices = 1 if gpus is None else gpus
+    use_multiple_devices = num_devices > 1
+    total = fake_result * num_devices + 1
+
+    metrics = trainer.callback_metrics
+    assert metrics["foo"] == (total if use_multiple_devices else fake_result)
+    assert metrics["foo_2"] == 2 * num_devices
+    assert metrics["foo_3"] == 2
+    assert metrics["foo_4"] == (total / num_devices if use_multiple_devices else 1)
+    assert metrics["foo_5"] == (fake_result * 2 + 1 if use_multiple_devices else fake_result * 2)
+    assert metrics["foo_6"] == (fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3)
+    assert metrics["foo_7"] == 2 * num_devices * 3
+    assert metrics["foo_8"] == 2
+    assert metrics["foo_9"] == ((fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result)
+    assert metrics["foo_10"] == 2
+    assert metrics["bar"] == fake_result * 3 * num_devices
+    assert metrics["bar_2"] == fake_result
+    assert metrics["bar_3"] == 2 + int(use_multiple_devices)
 
 
 @RunIf(min_gpus=2, special=True)
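
For context, here is a minimal sketch (not part of the diff) of the user-facing behavior this PR fixes. The model, data, and `Trainer` arguments below are illustrative assumptions; only the `self.log(..., sync_dist=True, reduce_fx="max")` call and the string-to-`ReduceOp` mapping come from the changes above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class MaxLossModel(pl.LightningModule):
    """Illustrative model; any LightningModule logging with `sync_dist=True` applies."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # Previously, a string `reduce_fx` other than "mean"/"avg" silently fell back to a
        # SUM all-reduce in `sync_ddp` (`op = ... else ReduceOp.SUM`). With this change, the
        # string is resolved via `getattr(ReduceOp, reduce_op.upper())`, so "max" really
        # takes the maximum of the logged value across processes.
        self.log("max_loss", loss, sync_dist=True, reduce_fx="max")
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    # Two-process DDP assumed here so the cross-rank reduction is actually exercised.
    trainer = pl.Trainer(max_epochs=1, gpus=2, accelerator="ddp")
    trainer.fit(MaxLossModel(), data)
```

Note that the "avg"/"mean" strings keep the previous sum-then-divide path (`op = ReduceOp.SUM` with `divide_by_world_size = True`), so averaged metrics behave as before; the behavioral change is that "max"/"min" now map to their actual reduce ops instead of silently summing.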