From dd80bfa3319428b1cdbafab40edb1a519d429ee2 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 26 Aug 2021 13:15:44 -0400 Subject: [PATCH 01/27] improve test --- test.py | 0 .../logging_/test_train_loop_logging.py | 51 ++++++++++++------- 2 files changed, 34 insertions(+), 17 deletions(-) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 308cad8fcd632..9a7b970a5df35 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -359,37 +359,54 @@ def get_expected(on_epoch, values): assert is_included if should_include else not is_included -@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))]) +class LoggingSyncDistModel(BoringModel): + def __init__(self, fake_result): + super().__init__() + self.fake_result = fake_result + + def training_step(self, batch, batch_idx): + self.log("foo", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") + self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") + self.log("foo_3", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + self.log("foo_4", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") + self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + return super().validation_step(batch, batch_idx) + + +@pytest.mark.parametrize("gpus", [ + None, + pytest.param(1, marks=RunIf(min_gpus=1)), + pytest.param(2, marks=RunIf(min_gpus=2)) +]) def test_logging_sync_dist_true(tmpdir, gpus): """ Tests to ensure that the sync_dist flag works (should just return the original value) """ fake_result = 1 - - class TestModel(BoringModel): - def training_step(self, batch, batch_idx): - self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") - self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") - return super().training_step(batch, batch_idx) - - def validation_step(self, batch, batch_idx): - self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") - return super().validation_step(batch, batch_idx) - - model = TestModel() + model = LoggingSyncDistModel(fake_result) trainer = Trainer( + max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1, - max_epochs=2, weights_summary=None, gpus=gpus, ) trainer.fit(model) - assert trainer.logged_metrics["foo"] == fake_result - assert trainer.logged_metrics["foo_2"] == 2 - assert trainer.logged_metrics["bar"] == fake_result + num_devices = 1 if gpus is None else gpus + + assert trainer.callback_metrics["foo"] == fake_result * num_devices + assert trainer.callback_metrics["foo_2"] == 2 * num_devices + assert trainer.callback_metrics["foo_3"] == 2 + assert trainer.callback_metrics["foo_4"] == fake_result + + assert trainer.callback_metrics["bar"] == fake_result * num_devices + assert trainer.callback_metrics["bar_2"] == fake_result @RunIf(min_gpus=2, special=True) From 2d7981fcdfb99cff3871f430374e28f30f230889 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 26 Aug 2021 14:19:50 -0400 
Subject: [PATCH 02/27] resolve bug --- pytorch_lightning/core/lightning.py | 7 +++- .../connectors/logger_connector/result.py | 20 +++++++--- .../logging_/test_train_loop_logging.py | 39 ++++++++++++++----- 3 files changed, 49 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 096333388c3b1..f634ea81370e1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -465,6 +465,11 @@ def log( "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." ) + if reduce_fx in ("max", "min"): + sync_dist_fn = self.trainer.training_type_plugin.all_gather + else: + sync_dist_fn = self.trainer.training_type_plugin.reduce or sync_ddp + results.log( self._current_fx_name, name, @@ -478,7 +483,7 @@ def log( dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, sync_dist=sync_dist and distributed_available(), - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist_fn=sync_dist_fn, sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 77079e6397f6f..f7443d46645db 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -13,12 +13,12 @@ # limitations under the License. from collections.abc import Generator from dataclasses import asdict, dataclass, replace -from functools import partial, wraps +from functools import lru_cache, partial, wraps from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import torch from torchmetrics import Metric - +import inspect from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device @@ -56,8 +56,11 @@ def __post_init__(self) -> None: @property def __call__(self) -> Any: + kwargs = dict(group=self.group) + if "reduce_op" in inspect.signature(self.fn).parameters: + kwargs["reduce_op"] = self.op return ( - partial(self.fn, reduce_op=self.op, group=self.group) + partial(self.fn, **kwargs) if self.should and not self.rank_zero_only else self.no_op ) @@ -124,7 +127,7 @@ def forked_name(self, on_step: bool) -> str: @property def is_mean_reduction(self) -> bool: - return self.reduce_fx is torch.mean + return self.reduce_fx is (torch.mean) @property def is_sum_reduction(self) -> bool: @@ -181,7 +184,10 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: self._forward_cache = value # performance: no need to accumulate on values only logged on_step if self.meta.on_step and not self.meta.on_epoch: - self.value = self.meta.sync(value) + value = self.meta.sync(value) + if self.meta.is_max_reduction or self.meta.is_min_reduction: + value = self.meta.reduce_fx(value) + self._forward_cache = self.value = value return # perform accumulation with reduction if self.meta.is_mean_reduction: @@ -202,7 +208,7 @@ def compute(self) -> torch.Tensor: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) return value / cumulated_batch_size elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction: - return value + return 
self.meta.reduce_fx(value) return self.value.compute() def reset(self) -> None: @@ -562,6 +568,8 @@ def any_tensor(_): if result_metric.meta.prog_bar: metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value) + print(metrics) + return metrics def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 9a7b970a5df35..39cb53b0b756b 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -18,7 +18,7 @@ import collections import itertools from re import escape - +import os import numpy as np import pytest import torch @@ -365,15 +365,24 @@ def __init__(self, fake_result): self.fake_result = fake_result def training_step(self, batch, batch_idx): - self.log("foo", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") - self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") - self.log("foo_3", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") - self.log("foo_4", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + value = self.fake_result + self.trainer.global_rank + self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum") + self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum") + self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean") + self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean") + self.log("foo_5", batch_idx + self.trainer.global_rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max") + + self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") + self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") + self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max") return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") + self.log("bar_3", batch_idx + self.trainer.global_rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max") return super().validation_step(batch, batch_idx) @@ -391,22 +400,32 @@ def test_logging_sync_dist_true(tmpdir, gpus): trainer = Trainer( max_epochs=1, default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, + limit_train_batches=3, + limit_val_batches=3, weights_summary=None, gpus=gpus, ) trainer.fit(model) num_devices = 1 if gpus is None else gpus + use_multiple_devices = num_devices > 1 + total = fake_result * num_devices + 1 - assert trainer.callback_metrics["foo"] == fake_result * num_devices + assert trainer.callback_metrics["foo"] == total if use_multiple_devices else fake_result assert trainer.callback_metrics["foo_2"] == 2 * num_devices assert trainer.callback_metrics["foo_3"] == 2 - assert trainer.callback_metrics["foo_4"] == fake_result + assert trainer.callback_metrics["foo_4"] == total / num_devices if use_multiple_devices else 1 + assert trainer.callback_metrics["foo_5"] == 
fake_result * 2 + 1 if use_multiple_devices else fake_result * 2 + + assert trainer.callback_metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2 + assert trainer.callback_metrics["foo_7"] == 2 * num_devices * 3 + assert trainer.callback_metrics["foo_8"] == 2 + assert trainer.callback_metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result + assert trainer.callback_metrics["foo_10"] == 2 - assert trainer.callback_metrics["bar"] == fake_result * num_devices + assert trainer.callback_metrics["bar"] == fake_result * 3 * num_devices assert trainer.callback_metrics["bar_2"] == fake_result + assert trainer.callback_metrics["bar_3"] == 2 + int(use_multiple_devices) @RunIf(min_gpus=2, special=True) From ba50ead441be4aaec563823f9effa2a0d616ad0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Aug 2021 18:21:48 +0000 Subject: [PATCH 03/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../connectors/logger_connector/result.py | 9 ++---- .../logging_/test_train_loop_logging.py | 29 ++++++++++++------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f7443d46645db..131be7900d3c0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect from collections.abc import Generator from dataclasses import asdict, dataclass, replace from functools import lru_cache, partial, wraps @@ -18,7 +19,7 @@ import torch from torchmetrics import Metric -import inspect + from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device @@ -59,11 +60,7 @@ def __call__(self) -> Any: kwargs = dict(group=self.group) if "reduce_op" in inspect.signature(self.fn).parameters: kwargs["reduce_op"] = self.op - return ( - partial(self.fn, **kwargs) - if self.should and not self.rank_zero_only - else self.no_op - ) + return partial(self.fn, **kwargs) if self.should and not self.rank_zero_only else self.no_op @staticmethod def no_op(value: Any, *_, **__) -> Any: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 39cb53b0b756b..f8d2190dd740e 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -17,8 +17,9 @@ import collections import itertools -from re import escape import os +from re import escape + import numpy as np import pytest import torch @@ -363,14 +364,16 @@ class LoggingSyncDistModel(BoringModel): def __init__(self, fake_result): super().__init__() self.fake_result = fake_result - + def training_step(self, batch, batch_idx): value = self.fake_result + self.trainer.global_rank self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum") self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum") self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean") self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean") - self.log("foo_5", batch_idx + self.trainer.global_rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max") + self.log( + "foo_5", batch_idx + self.trainer.global_rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max" + ) self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") @@ -382,15 +385,15 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") - self.log("bar_3", batch_idx + self.trainer.global_rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max") + self.log( + "bar_3", batch_idx + self.trainer.global_rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max" + ) return super().validation_step(batch, batch_idx) -@pytest.mark.parametrize("gpus", [ - None, - pytest.param(1, marks=RunIf(min_gpus=1)), - pytest.param(2, marks=RunIf(min_gpus=2)) -]) +@pytest.mark.parametrize( + "gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))] +) def test_logging_sync_dist_true(tmpdir, gpus): """ Tests to ensure that the sync_dist flag works (should just return the original value) @@ -420,9 +423,13 @@ def test_logging_sync_dist_true(tmpdir, gpus): assert trainer.callback_metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2 assert 
trainer.callback_metrics["foo_7"] == 2 * num_devices * 3 assert trainer.callback_metrics["foo_8"] == 2 - assert trainer.callback_metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result + assert ( + trainer.callback_metrics["foo_9"] == (fake_result * 2 + 1) / num_devices + if use_multiple_devices + else fake_result + ) assert trainer.callback_metrics["foo_10"] == 2 - + assert trainer.callback_metrics["bar"] == fake_result * 3 * num_devices assert trainer.callback_metrics["bar_2"] == fake_result assert trainer.callback_metrics["bar_3"] == 2 + int(use_multiple_devices) From 97ee3f84b21deee9600a1c40c6da953a2e09e38e Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 26 Aug 2021 14:24:04 -0400 Subject: [PATCH 04/27] update changelog --- CHANGELOG.md | 3 +++ .../trainer/connectors/logger_connector/result.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36d90ae213fdb..272bfca4f20f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -233,6 +233,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954)) +- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean, max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) + + ## [1.4.3] - 2021-08-17 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f7443d46645db..a857c6f61db22 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -208,7 +208,9 @@ def compute(self) -> torch.Tensor: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) return value / cumulated_batch_size elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction: - return self.meta.reduce_fx(value) + if value.dim() > 0: + value = self.meta.reduce_fx(value) + return value return self.value.compute() def reset(self) -> None: From c8cd17e210efe7922736927b481a800191c33884 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 26 Aug 2021 14:25:33 -0400 Subject: [PATCH 05/27] remove test --- test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 226f9a4f68a520f5e09c1803112d7b2cea3871ec Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 19:31:00 +0100 Subject: [PATCH 06/27] update --- .../connectors/logger_connector/result.py | 18 +++++++++++------- .../logging_/test_train_loop_logging.py | 1 - 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 131be7900d3c0..943a49a01c443 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -14,7 +14,7 @@ import inspect from collections.abc import Generator from dataclasses import asdict, dataclass, replace -from functools import 
lru_cache, partial, wraps +from functools import partial, wraps from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import torch @@ -52,15 +52,19 @@ class _Sync: group: Optional[Any] = None def __post_init__(self) -> None: - if self.fn is None: - self.fn = self.no_op + if self.fn: + kwargs = dict(group=self.group) + if "reduce_op" in inspect.signature(self.fn).parameters: + kwargs["reduce_op"] = self.op + self.fn = ( + partial(self.fn, **kwargs) + if self.fn is not None and self.should and not self.rank_zero_only + else self.no_op + ) @property def __call__(self) -> Any: - kwargs = dict(group=self.group) - if "reduce_op" in inspect.signature(self.fn).parameters: - kwargs["reduce_op"] = self.op - return partial(self.fn, **kwargs) if self.should and not self.rank_zero_only else self.no_op + return self.fn @staticmethod def no_op(value: Any, *_, **__) -> Any: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index f8d2190dd740e..ad27c596a5df3 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -17,7 +17,6 @@ import collections import itertools -import os from re import escape import numpy as np From 54df136a0a7fb95b0f71932367ada6964279dae6 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 26 Aug 2021 14:42:27 -0400 Subject: [PATCH 07/27] improvement --- .../trainer/connectors/logger_connector/result.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 8181065887190..a3aa295b22624 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -56,15 +56,12 @@ def __post_init__(self) -> None: kwargs = dict(group=self.group) if "reduce_op" in inspect.signature(self.fn).parameters: kwargs["reduce_op"] = self.op - self.fn = ( - partial(self.fn, **kwargs) - if self.fn is not None and self.should and not self.rank_zero_only - else self.no_op - ) + + self.fn_call = partial(self.fn, **kwargs) if self.fn and self.should and not self.rank_zero_only else self.no_op @property def __call__(self) -> Any: - return self.fn + return self.fn_call @staticmethod def no_op(value: Any, *_, **__) -> Any: @@ -471,6 +468,7 @@ def log( # check the stored metadata and the current one match elif meta != self[key].meta: + print(meta, self[key].meta) raise MisconfigurationException( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed" ) From 691872412d7b2dd51f275add597914dcc1d6ead5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 20:00:41 +0100 Subject: [PATCH 08/27] update --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 8181065887190..3fb40e47c8da4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -571,8 +571,6 @@ def any_tensor(_): if result_metric.meta.prog_bar: metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value) - print(metrics) - return metrics def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: From ac2a13e2cc6f39531acc9261441b19bd93176372 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 20:02:48 +0100 Subject: [PATCH 09/27] resolve tests --- .../trainer/connectors/logger_connector/result.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 2c1ff2c916397..f10b0c8e820b7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -59,6 +59,9 @@ def __post_init__(self) -> None: self.fn_call = partial(self.fn, **kwargs) if self.fn and self.should and not self.rank_zero_only else self.no_op + if not self.fn: + self.fn = self.no_op + @property def __call__(self) -> Any: return self.fn_call From 4ed928e49c9a618d16e7960e1d400b44065f9a9e Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 27 Aug 2021 04:30:06 -0400 Subject: [PATCH 10/27] update on comments --- .../connectors/logger_connector/result.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f10b0c8e820b7..18575433be3db 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -52,19 +52,18 @@ class _Sync: group: Optional[Any] = None def __post_init__(self) -> None: - if self.fn: - kwargs = dict(group=self.group) + if self.fn is None: + self.fn = self.no_op + if self.should and not self.rank_zero_only: + kwargs = {"group": self.group} if "reduce_op" in inspect.signature(self.fn).parameters: kwargs["reduce_op"] = self.op - - self.fn_call = partial(self.fn, **kwargs) if self.fn and self.should and not self.rank_zero_only else self.no_op - - if not self.fn: - self.fn = self.no_op - + self._fn = partial(self.fn, **kwargs) + else: + self._fn = self.no_op @property def __call__(self) -> Any: - return self.fn_call + return self._fn @staticmethod def no_op(value: Any, *_, **__) -> Any: @@ -471,7 +470,6 @@ def log( # check the stored metadata and the current one match elif meta != self[key].meta: - print(meta, self[key].meta) raise MisconfigurationException( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed" ) From fdbd065171c9228a23f5b38cd1ac4538133c7ba6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Aug 2021 08:31:18 +0000 Subject: [PATCH 11/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 18575433be3db..bd35245e2f5e0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -61,6 +61,7 @@ def __post_init__(self) -> None: self._fn = partial(self.fn, **kwargs) else: self._fn = self.no_op + @property def __call__(self) -> Any: return self._fn From 55238dee801bde41715ecdffbe8496729ad38bee Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 27 Aug 2021 05:29:52 -0400 Subject: [PATCH 12/27] resolve test --- .../connectors/logger_connector/result.py | 20 +++++++++++++++++-- tests/core/test_metric_result_integration.py | 1 + 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 18575433be3db..93581e11f10f4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,6 +45,7 @@ class MetricSource(LightningEnum): @dataclass class _Sync: + fn: Optional[Callable] = None should: bool = False rank_zero_only: bool = False @@ -52,8 +53,20 @@ class _Sync: group: Optional[Any] = None def __post_init__(self) -> None: + self._generate_sync_fn() + + def set_should(self, should: bool) -> None: + self.should = should + # when should changes, the `sync fn` need to be re-generated. + self._generate_sync_fn() + + def _generate_sync_fn(self): + """Used to compute the syncing function and cache it.""" if self.fn is None: self.fn = self.no_op + + # save the function as `_fn` as the meta are being re-created + # and the object references need to match. if self.should and not self.rank_zero_only: kwargs = {"group": self.group} if "reduce_op" in inspect.signature(self.fn).parameters: @@ -61,6 +74,7 @@ def __post_init__(self) -> None: self._fn = partial(self.fn, **kwargs) else: self._fn = self.no_op + @property def __call__(self) -> Any: return self._fn @@ -508,9 +522,9 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten if not result_metric._computed: # always reduce on epoch end should = result_metric.meta.sync.should - result_metric.meta.sync.should = True + result_metric.meta.sync.set_should(True) result_metric.compute() - result_metric.meta.sync.should = should + result_metric.meta.sync.set_should(should) cache = result_metric._computed if cache is not None and not result_metric.meta.enable_graph: return cache.detach() @@ -688,6 +702,8 @@ def load_state_dict( if not metrics: return + + # iterate through result metrics and re-attached Metric references on reload. 
result_metrics = self.result_metrics for metric_attribute, metric in metrics.items(): for result_metric in result_metrics: diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 7c6a985d09f33..fda7ea813c6c4 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -449,6 +449,7 @@ def training_step(self, batch, batch_idx): def on_epoch_end(self) -> None: if self.trainer.fit_loop.restarting: total = sum(range(5)) * num_processes + print(self.results["training_step.tracking"].meta) metrics = self.results.metrics(on_step=False) assert self.results["training_step.tracking"].value == total assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 From c87691fe8256d746aadc8ed815afe70fc4e2ed38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Aug 2021 09:31:12 +0000 Subject: [PATCH 13/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/result.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 93581e11f10f4..a661cd3f32f29 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,7 +45,7 @@ class MetricSource(LightningEnum): @dataclass class _Sync: - + fn: Optional[Callable] = None should: bool = False rank_zero_only: bool = False @@ -57,7 +57,7 @@ def __post_init__(self) -> None: def set_should(self, should: bool) -> None: self.should = should - # when should changes, the `sync fn` need to be re-generated. + # when should changes, the `sync fn` need to be re-generated. self._generate_sync_fn() def _generate_sync_fn(self): @@ -65,7 +65,7 @@ def _generate_sync_fn(self): if self.fn is None: self.fn = self.no_op - # save the function as `_fn` as the meta are being re-created + # save the function as `_fn` as the meta are being re-created # and the object references need to match. if self.should and not self.rank_zero_only: kwargs = {"group": self.group} @@ -702,8 +702,8 @@ def load_state_dict( if not metrics: return - - # iterate through result metrics and re-attached Metric references on reload. + + # iterate through result metrics and re-attached Metric references on reload. 
result_metrics = self.result_metrics for metric_attribute, metric in metrics.items(): for result_metric in result_metrics: From 18f148f0d4e6f5823afc710edcd1904cfd1879c5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 27 Aug 2021 10:31:22 +0100 Subject: [PATCH 14/27] resolve typo --- .../trainer/connectors/logger_connector/result.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 93581e11f10f4..d22c3ecc272b0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,7 +45,6 @@ class MetricSource(LightningEnum): @dataclass class _Sync: - fn: Optional[Callable] = None should: bool = False rank_zero_only: bool = False @@ -57,7 +56,7 @@ def __post_init__(self) -> None: def set_should(self, should: bool) -> None: self.should = should - # when should changes, the `sync fn` need to be re-generated. + # when should changes, the `sync fn` need to be re-generated. self._generate_sync_fn() def _generate_sync_fn(self): @@ -65,7 +64,7 @@ def _generate_sync_fn(self): if self.fn is None: self.fn = self.no_op - # save the function as `_fn` as the meta are being re-created + # save the function as `_fn` as the meta are being re-created # and the object references need to match. if self.should and not self.rank_zero_only: kwargs = {"group": self.group} @@ -702,8 +701,8 @@ def load_state_dict( if not metrics: return - - # iterate through result metrics and re-attached Metric references on reload. + + # iterate through result metrics and re-attached Metric references on reload. result_metrics = self.result_metrics for metric_attribute, metric in metrics.items(): for result_metric in result_metrics: From 662f720f8b5f9d5eb23d77e39b99591bf0d60dc1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 27 Aug 2021 10:34:40 +0100 Subject: [PATCH 15/27] update --- .../trainer/connectors/logger_connector/result.py | 2 +- tests/trainer/logging_/test_train_loop_logging.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index d22c3ecc272b0..7bb363391063c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -140,7 +140,7 @@ def forked_name(self, on_step: bool) -> str: @property def is_mean_reduction(self) -> bool: - return self.reduce_fx is (torch.mean) + return self.reduce_fx is torch.mean @property def is_sum_reduction(self) -> bool: diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index ad27c596a5df3..8d53fa623e109 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -364,15 +364,17 @@ def __init__(self, fake_result): super().__init__() self.fake_result = fake_result + @property + def rank(self) -> int: + return self.trainer.global_rank + def training_step(self, batch, batch_idx): - value = self.fake_result + self.trainer.global_rank + value = self.fake_result + self.rank self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum") self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum") self.log("foo_3", 2, on_step=True, 
on_epoch=False, sync_dist=True, reduce_fx="mean") self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean") - self.log( - "foo_5", batch_idx + self.trainer.global_rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max" - ) + self.log("foo_5", batch_idx + self.rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max") self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") @@ -384,9 +386,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum") self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean") - self.log( - "bar_3", batch_idx + self.trainer.global_rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max" - ) + self.log("bar_3", batch_idx + self.rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max") return super().validation_step(batch, batch_idx) From 7f77ba0712d376e91877bc351003d5ec8cddb830 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 27 Aug 2021 16:02:38 +0530 Subject: [PATCH 16/27] Update tests/core/test_metric_result_integration.py --- tests/core/test_metric_result_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index fda7ea813c6c4..7c6a985d09f33 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -449,7 +449,6 @@ def training_step(self, batch, batch_idx): def on_epoch_end(self) -> None: if self.trainer.fit_loop.restarting: total = sum(range(5)) * num_processes - print(self.results["training_step.tracking"].meta) metrics = self.results.metrics(on_step=False) assert self.results["training_step.tracking"].value == total assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 From b72571d781c6d42a1febd73e0c363e01bd494d7c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 27 Aug 2021 15:26:34 +0200 Subject: [PATCH 17/27] Refactor and simplify --- .../connectors/logger_connector/result.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7bb363391063c..e6142abeb4a33 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -46,7 +46,7 @@ class MetricSource(LightningEnum): @dataclass class _Sync: fn: Optional[Callable] = None - should: bool = False + _should: bool = False rank_zero_only: bool = False op: Optional[str] = None group: Optional[Any] = None @@ -54,9 +54,14 @@ class _Sync: def __post_init__(self) -> None: self._generate_sync_fn() - def set_should(self, should: bool) -> None: - self.should = should - # when should changes, the `sync fn` need to be re-generated. + @property + def should(self) -> bool: + return self._should + + @should.setter + def should(self, should: bool) -> None: + self._should = should + # `self._fn` needs to be re-generated. 
self._generate_sync_fn() def _generate_sync_fn(self): @@ -91,31 +96,28 @@ class _Metadata: logger: bool = True on_step: bool = False on_epoch: bool = True - _reduce_fx: Callable = torch.mean + reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None _sync: Optional[_Sync] = None - @property - def reduce_fx(self) -> Callable: - return self._reduce_fx + def __post_init__(self) -> None: + self._parse_reduce_fx() - @reduce_fx.setter - def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None: + def _parse_reduce_fx(self) -> None: error = ( "Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported." " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`." - f" Found: {reduce_fx}" + f" Found: {self.reduce_fx}" ) - self._reduce_fx = reduce_fx - if isinstance(reduce_fx, str): - reduce_fx = reduce_fx.lower() + if isinstance(self.reduce_fx, str): + reduce_fx = self.reduce_fx.lower() if reduce_fx == "avg": reduce_fx = "mean" if reduce_fx not in ("min", "max", "mean", "sum"): raise MisconfigurationException(error) - self._reduce_fx = getattr(torch, reduce_fx) + self.reduce_fx = getattr(torch, reduce_fx) elif self.is_custom_reduction: raise MisconfigurationException(error) @@ -470,12 +472,12 @@ def log( logger=logger, on_step=on_step, on_epoch=on_epoch, + reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) - meta.reduce_fx = reduce_fx - meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) + meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) # register logged value if it doesn't exist if key not in self: @@ -521,9 +523,9 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten if not result_metric._computed: # always reduce on epoch end should = result_metric.meta.sync.should - result_metric.meta.sync.set_should(True) + result_metric.meta.sync.should = True result_metric.compute() - result_metric.meta.sync.set_should(should) + result_metric.meta.sync.should = should cache = result_metric._computed if cache is not None and not result_metric.meta.enable_graph: return cache.detach() From 9d2a7856c7c09cf9499bc61500048b6d29248260 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 27 Aug 2021 09:51:47 -0400 Subject: [PATCH 18/27] update --- pytorch_lightning/core/lightning.py | 10 +++------- .../trainer/connectors/logger_connector/result.py | 12 +++--------- pytorch_lightning/utilities/distributed.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f634ea81370e1..26a4279015c9d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -18,6 +18,7 @@ import logging import numbers import os +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection import tempfile from abc import ABC from contextlib import contextmanager @@ -409,7 +410,7 @@ def log( "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." 
" This is most likely because the model hasn't been passed to the `Trainer`" ) - results = self.trainer._results + results: Optional[ResultCollection] = self.trainer._results if results is None: raise MisconfigurationException( "You are trying to `self.log()` but the loop `ResultCollection` is not registered" @@ -465,11 +466,6 @@ def log( "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided." ) - if reduce_fx in ("max", "min"): - sync_dist_fn = self.trainer.training_type_plugin.all_gather - else: - sync_dist_fn = self.trainer.training_type_plugin.reduce or sync_ddp - results.log( self._current_fx_name, name, @@ -483,7 +479,7 @@ def log( dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, sync_dist=sync_dist and distributed_available(), - sync_dist_fn=sync_dist_fn, + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 93581e11f10f4..a011bacfe8cfa 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -195,14 +195,11 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: value = value.float() - self._forward_cache = value # performance: no need to accumulate on values only logged on_step if self.meta.on_step and not self.meta.on_epoch: - value = self.meta.sync(value) - if self.meta.is_max_reduction or self.meta.is_min_reduction: - value = self.meta.reduce_fx(value) - self._forward_cache = self.value = value + self._forward_cache = self.value = self.meta.sync(value) return + self._forward_cache = value # perform accumulation with reduction if self.meta.is_mean_reduction: self.value += value.mean() * batch_size @@ -221,10 +218,7 @@ def compute(self) -> torch.Tensor: if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) return value / cumulated_batch_size - elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction: - if value.dim() > 0: - value = self.meta.reduce_fx(value) - return value + return value return self.value.compute() def reset(self) -> None: diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 4f254b6824489..73de501fc9911 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -181,10 +181,14 @@ def sync_ddp( if group is None: group = torch.distributed.group.WORLD - op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM - - if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): - divide_by_world_size = True + if isinstance(reduce_op, str): + if reduce_op.lower() in ("avg", "mean"): + op = ReduceOp.SUM + divide_by_world_size = True + else: + op = getattr(ReduceOp, reduce_op.upper()) + else: + op = reduce_op # sync all processes before reduction torch.distributed.barrier(group=group) From 95b6a33f4255d507c852da193de8eaa3ea6fbb8b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Aug 2021 13:53:03 +0000 Subject: [PATCH 19/27] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 26a4279015c9d..bb405c1c6cba9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -18,7 +18,6 @@ import logging import numbers import os -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection import tempfile from abc import ABC from contextlib import contextmanager @@ -37,6 +36,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem From 5a4eab2ff91221dbfd4d5f47aa713a5f9726d15c Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 27 Aug 2021 09:53:53 -0400 Subject: [PATCH 20/27] update --- .../connectors/logger_connector/result.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 15310b7e316b0..a011bacfe8cfa 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,8 +45,9 @@ class MetricSource(LightningEnum): @dataclass class _Sync: + fn: Optional[Callable] = None - _should: bool = False + should: bool = False rank_zero_only: bool = False op: Optional[str] = None group: Optional[Any] = None @@ -54,14 +55,9 @@ class _Sync: def __post_init__(self) -> None: self._generate_sync_fn() - @property - def should(self) -> bool: - return self._should - - @should.setter - def should(self, should: bool) -> None: - self._should = should - # `self._fn` needs to be re-generated. + def set_should(self, should: bool) -> None: + self.should = should + # when should changes, the `sync fn` need to be re-generated. self._generate_sync_fn() def _generate_sync_fn(self): @@ -69,7 +65,7 @@ def _generate_sync_fn(self): if self.fn is None: self.fn = self.no_op - # save the function as `_fn` as the meta are being re-created + # save the function as `_fn` as the meta are being re-created # and the object references need to match. if self.should and not self.rank_zero_only: kwargs = {"group": self.group} @@ -96,28 +92,31 @@ class _Metadata: logger: bool = True on_step: bool = False on_epoch: bool = True - reduce_fx: Callable = torch.mean + _reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None _sync: Optional[_Sync] = None - def __post_init__(self) -> None: - self._parse_reduce_fx() + @property + def reduce_fx(self) -> Callable: + return self._reduce_fx - def _parse_reduce_fx(self) -> None: + @reduce_fx.setter + def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None: error = ( "Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported." 
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`." - f" Found: {self.reduce_fx}" + f" Found: {reduce_fx}" ) - if isinstance(self.reduce_fx, str): - reduce_fx = self.reduce_fx.lower() + self._reduce_fx = reduce_fx + if isinstance(reduce_fx, str): + reduce_fx = reduce_fx.lower() if reduce_fx == "avg": reduce_fx = "mean" if reduce_fx not in ("min", "max", "mean", "sum"): raise MisconfigurationException(error) - self.reduce_fx = getattr(torch, reduce_fx) + self._reduce_fx = getattr(torch, reduce_fx) elif self.is_custom_reduction: raise MisconfigurationException(error) @@ -142,7 +141,7 @@ def forked_name(self, on_step: bool) -> str: @property def is_mean_reduction(self) -> bool: - return self.reduce_fx is torch.mean + return self.reduce_fx is (torch.mean) @property def is_sum_reduction(self) -> bool: @@ -466,12 +465,12 @@ def log( logger=logger, on_step=on_step, on_epoch=on_epoch, - reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) - meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) + meta.reduce_fx = reduce_fx + meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) # register logged value if it doesn't exist if key not in self: @@ -517,9 +516,9 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten if not result_metric._computed: # always reduce on epoch end should = result_metric.meta.sync.should - result_metric.meta.sync.should = True + result_metric.meta.sync.set_should(True) result_metric.compute() - result_metric.meta.sync.should = should + result_metric.meta.sync.set_should(should) cache = result_metric._computed if cache is not None and not result_metric.meta.enable_graph: return cache.detach() @@ -697,8 +696,8 @@ def load_state_dict( if not metrics: return - - # iterate through result metrics and re-attached Metric references on reload. + + # iterate through result metrics and re-attached Metric references on reload. result_metrics = self.result_metrics for metric_attribute, metric in metrics.items(): for result_metric in result_metrics: From d2c5bc6bbe818356f2a138b21dc7527faa542f1a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Aug 2021 13:55:08 +0000 Subject: [PATCH 21/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/result.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a011bacfe8cfa..dbdcbf71df0d6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -45,7 +45,7 @@ class MetricSource(LightningEnum): @dataclass class _Sync: - + fn: Optional[Callable] = None should: bool = False rank_zero_only: bool = False @@ -57,7 +57,7 @@ def __post_init__(self) -> None: def set_should(self, should: bool) -> None: self.should = should - # when should changes, the `sync fn` need to be re-generated. + # when should changes, the `sync fn` need to be re-generated. 
self._generate_sync_fn() def _generate_sync_fn(self): @@ -65,7 +65,7 @@ def _generate_sync_fn(self): if self.fn is None: self.fn = self.no_op - # save the function as `_fn` as the meta are being re-created + # save the function as `_fn` as the meta are being re-created # and the object references need to match. if self.should and not self.rank_zero_only: kwargs = {"group": self.group} @@ -696,8 +696,8 @@ def load_state_dict( if not metrics: return - - # iterate through result metrics and re-attached Metric references on reload. + + # iterate through result metrics and re-attached Metric references on reload. result_metrics = self.result_metrics for metric_attribute, metric in metrics.items(): for result_metric in result_metrics: From eb791d6af6fe840031344f4c9a348a4c9d724d8c Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 27 Aug 2021 09:57:23 -0400 Subject: [PATCH 22/27] update --- .../connectors/logger_connector/result.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a011bacfe8cfa..d3a7b1828ec58 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -62,22 +62,14 @@ def set_should(self, should: bool) -> None: def _generate_sync_fn(self): """Used to compute the syncing function and cache it.""" - if self.fn is None: - self.fn = self.no_op - - # save the function as `_fn` as the meta are being re-created - # and the object references need to match. - if self.should and not self.rank_zero_only: - kwargs = {"group": self.group} - if "reduce_op" in inspect.signature(self.fn).parameters: - kwargs["reduce_op"] = self.op - self._fn = partial(self.fn, **kwargs) - else: - self._fn = self.no_op + if self.fn and self.should and not self.rank_zero_only: + # save the function as `_fn` as the meta are being re-created + # and the object references need to match. + self._fn = partial(self.fn, reduce_op=self.op, group=self.group) @property def __call__(self) -> Any: - return self._fn + return getattr(self, "_fn", self.no_op) @staticmethod def no_op(value: Any, *_, **__) -> Any: From cdf6438b70d1c5c1515452b1bf83a8b2280125e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Aug 2021 13:59:08 +0000 Subject: [PATCH 23/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c8acf8da37f2c..0fd92f948c1b3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -63,7 +63,7 @@ def set_should(self, should: bool) -> None: def _generate_sync_fn(self): """Used to compute the syncing function and cache it.""" if self.fn and self.should and not self.rank_zero_only: - # save the function as `_fn` as the meta are being re-created + # save the function as `_fn` as the meta are being re-created # and the object references need to match. 
self._fn = partial(self.fn, reduce_op=self.op, group=self.group) From 0d47ec2ae9c67cf5f080592975c123cf7a60d391 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 27 Aug 2021 16:20:36 +0200 Subject: [PATCH 24/27] Push back changes :) --- pytorch_lightning/core/lightning.py | 3 +- .../connectors/logger_connector/result.py | 55 +++++++++---------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bb405c1c6cba9..096333388c3b1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,7 +36,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -410,7 +409,7 @@ def log( "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." " This is most likely because the model hasn't been passed to the `Trainer`" ) - results: Optional[ResultCollection] = self.trainer._results + results = self.trainer._results if results is None: raise MisconfigurationException( "You are trying to `self.log()` but the loop `ResultCollection` is not registered" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 0fd92f948c1b3..38d09137b33ad 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from collections.abc import Generator from dataclasses import asdict, dataclass, replace from functools import partial, wraps @@ -45,9 +44,8 @@ class MetricSource(LightningEnum): @dataclass class _Sync: - fn: Optional[Callable] = None - should: bool = False + _should: bool = False rank_zero_only: bool = False op: Optional[str] = None group: Optional[Any] = None @@ -55,21 +53,25 @@ class _Sync: def __post_init__(self) -> None: self._generate_sync_fn() - def set_should(self, should: bool) -> None: - self.should = should - # when should changes, the `sync fn` need to be re-generated. + @property + def should(self) -> bool: + return self._should + + @should.setter + def should(self, should: bool) -> None: + self._should = should + # `self._fn` needs to be re-generated. self._generate_sync_fn() - def _generate_sync_fn(self): + def _generate_sync_fn(self) -> None: """Used to compute the syncing function and cache it.""" - if self.fn and self.should and not self.rank_zero_only: - # save the function as `_fn` as the meta are being re-created - # and the object references need to match. - self._fn = partial(self.fn, reduce_op=self.op, group=self.group) + fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn + # save the function as `_fn` as the meta are being re-created and the object references need to match. 
+ self._fn = partial(fn, reduce_op=self.op, group=self.group) @property def __call__(self) -> Any: - return getattr(self, "_fn", self.no_op) + return self._fn @staticmethod def no_op(value: Any, *_, **__) -> Any: @@ -84,31 +86,28 @@ class _Metadata: logger: bool = True on_step: bool = False on_epoch: bool = True - _reduce_fx: Callable = torch.mean + reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None metric_attribute: Optional[str] = None _sync: Optional[_Sync] = None - @property - def reduce_fx(self) -> Callable: - return self._reduce_fx + def __post_init__(self) -> None: + self._parse_reduce_fx() - @reduce_fx.setter - def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None: + def _parse_reduce_fx(self) -> None: error = ( "Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported." " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`." - f" Found: {reduce_fx}" + f" Found: {self.reduce_fx}" ) - self._reduce_fx = reduce_fx - if isinstance(reduce_fx, str): - reduce_fx = reduce_fx.lower() + if isinstance(self.reduce_fx, str): + reduce_fx = self.reduce_fx.lower() if reduce_fx == "avg": reduce_fx = "mean" if reduce_fx not in ("min", "max", "mean", "sum"): raise MisconfigurationException(error) - self._reduce_fx = getattr(torch, reduce_fx) + self.reduce_fx = getattr(torch, reduce_fx) elif self.is_custom_reduction: raise MisconfigurationException(error) @@ -133,7 +132,7 @@ def forked_name(self, on_step: bool) -> str: @property def is_mean_reduction(self) -> bool: - return self.reduce_fx is (torch.mean) + return self.reduce_fx is torch.mean @property def is_sum_reduction(self) -> bool: @@ -457,12 +456,12 @@ def log( logger=logger, on_step=on_step, on_epoch=on_epoch, + reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) - meta.reduce_fx = reduce_fx - meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) + meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) # register logged value if it doesn't exist if key not in self: @@ -508,9 +507,9 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten if not result_metric._computed: # always reduce on epoch end should = result_metric.meta.sync.should - result_metric.meta.sync.set_should(True) + result_metric.meta.sync.should = True result_metric.compute() - result_metric.meta.sync.set_should(should) + result_metric.meta.sync.should = should cache = result_metric._computed if cache is not None and not result_metric.meta.enable_graph: return cache.detach() From c5ef4441dc666ee34d8b9c896441333314ae32b2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 27 Aug 2021 16:22:45 +0200 Subject: [PATCH 25/27] Fix test --- tests/core/test_metric_result_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 7c6a985d09f33..dd03407080953 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -27,7 +27,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection +from 
pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -336,7 +336,7 @@ def on_save_checkpoint(self, checkpoint) -> None: # default sync fn new_results = ResultCollection(False, device) new_results.load_state_dict(state_dict, map_location="cpu") - assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op + assert new_results["validation_step.v"].meta.sync.fn is None # check map location assert new_results["validation_step.v"].value.device.type == "cpu" From 36963007c9309dc8f69c388be8ab20272ad5555e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 27 Aug 2021 16:27:16 +0200 Subject: [PATCH 26/27] Cache callback metrics --- .../logging_/test_train_loop_logging.py | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 8d53fa623e109..385870fedd890 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -413,25 +413,20 @@ def test_logging_sync_dist_true(tmpdir, gpus): use_multiple_devices = num_devices > 1 total = fake_result * num_devices + 1 - assert trainer.callback_metrics["foo"] == total if use_multiple_devices else fake_result - assert trainer.callback_metrics["foo_2"] == 2 * num_devices - assert trainer.callback_metrics["foo_3"] == 2 - assert trainer.callback_metrics["foo_4"] == total / num_devices if use_multiple_devices else 1 - assert trainer.callback_metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2 - - assert trainer.callback_metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2 - assert trainer.callback_metrics["foo_7"] == 2 * num_devices * 3 - assert trainer.callback_metrics["foo_8"] == 2 - assert ( - trainer.callback_metrics["foo_9"] == (fake_result * 2 + 1) / num_devices - if use_multiple_devices - else fake_result - ) - assert trainer.callback_metrics["foo_10"] == 2 - - assert trainer.callback_metrics["bar"] == fake_result * 3 * num_devices - assert trainer.callback_metrics["bar_2"] == fake_result - assert trainer.callback_metrics["bar_3"] == 2 + int(use_multiple_devices) + metrics = trainer.callback_metrics + assert metrics["foo"] == total if use_multiple_devices else fake_result + assert metrics["foo_2"] == 2 * num_devices + assert metrics["foo_3"] == 2 + assert metrics["foo_4"] == total / num_devices if use_multiple_devices else 1 + assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2 + assert metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2 + assert metrics["foo_7"] == 2 * num_devices * 3 + assert metrics["foo_8"] == 2 + assert metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result + assert metrics["foo_10"] == 2 + assert metrics["bar"] == fake_result * 3 * num_devices + assert metrics["bar_2"] == fake_result + assert metrics["bar_3"] == 2 + int(use_multiple_devices) @RunIf(min_gpus=2, special=True) From 4afa246b8524d6044ffb3440fe6de505e495aaf9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 27 Aug 2021 16:41:55 +0200 Subject: [PATCH 27/27] Fix test --- tests/core/test_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/tests/core/test_results.py b/tests/core/test_results.py index bc3a35e95c21c..9d164b989f434 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize): def _ddp_test_fn(rank, worldsize): _setup_ddp(rank, worldsize) tensor = torch.tensor([1.0]) - sync = _Sync(sync_ddp_if_available, should=True, op="SUM") + sync = _Sync(sync_ddp_if_available, _should=True, op="SUM") actual = sync(tensor) assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
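
For readers following patches 22 through 24, the snippet below is a minimal, dependency-free sketch of the `_Sync` caching pattern the series converges on: the sync callable is resolved once into a cached `partial` stored as `_fn`, and assigning to `should` regenerates that cache so the object reference stays valid after the metadata is re-created. The class name `Sync`, the `fake_all_reduce` helper, and the demo values are assumptions made for illustration; this is not the exact Lightning implementation.

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional


@dataclass
class Sync:
    """Illustrative stand-in for `_Sync`: resolve and cache the sync function once."""

    fn: Optional[Callable] = None
    _should: bool = False
    rank_zero_only: bool = False
    op: Optional[str] = None
    group: Optional[Any] = None

    def __post_init__(self) -> None:
        self._generate_sync_fn()

    @property
    def should(self) -> bool:
        return self._should

    @should.setter
    def should(self, should: bool) -> None:
        self._should = should
        # the cached `_fn` must be regenerated whenever `should` changes
        self._generate_sync_fn()

    def _generate_sync_fn(self) -> None:
        # fall back to the identity function when syncing is disabled or restricted
        fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
        self._fn = partial(fn, reduce_op=self.op, group=self.group)

    @property
    def __call__(self) -> Any:
        # calling the object is just an attribute lookup on the cached partial
        return self._fn

    @staticmethod
    def no_op(value: Any, *_: Any, **__: Any) -> Any:
        return value


if __name__ == "__main__":
    def fake_all_reduce(value, reduce_op=None, group=None):
        # pretend two ranks contributed the same value
        return value * 2

    sync = Sync(fake_all_reduce, _should=True, op="sum")
    assert sync(3) == 6   # reduced through the cached partial
    sync.should = False   # setter rebuilds `_fn`, which now falls back to no_op
    assert sync(3) == 3

Note how patch 22 also drops the `inspect.signature` check and passes `reduce_op` unconditionally; the sketch follows that final behaviour.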