Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Resolve logging reduction when using sync_dist + reduce_fx={mean, max} #9142

Merged
merged 36 commits into from
Aug 27, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
dd80bfa
improve test
tchaton Aug 26, 2021
2d7981f
resolve bug
tchaton Aug 26, 2021
ba50ead
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
97ee3f8
update changelog
tchaton Aug 26, 2021
c8cd17e
remove test
tchaton Aug 26, 2021
a1602ae
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 26, 2021
226f9a4
update
tchaton Aug 26, 2021
7dc0246
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 26, 2021
54df136
improvement
tchaton Aug 26, 2021
6918724
update
tchaton Aug 26, 2021
cb51f34
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 26, 2021
ac2a13e
resolve tests
tchaton Aug 26, 2021
4ed928e
update on comments
tchaton Aug 27, 2021
fdbd065
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
55238de
resolve test
tchaton Aug 27, 2021
a7c596b
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 27, 2021
c87691f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
18f148f
resolve typo
tchaton Aug 27, 2021
f7912ba
update
tchaton Aug 27, 2021
662f720
update
tchaton Aug 27, 2021
7f77ba0
Update tests/core/test_metric_result_integration.py
rohitgr7 Aug 27, 2021
b72571d
Refactor and simplify
carmocca Aug 27, 2021
9d2a785
update
tchaton Aug 27, 2021
5c2367c
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 27, 2021
95b6a33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
5a4eab2
update
tchaton Aug 27, 2021
266da25
Merge branch 'logging' of https://github.com/PyTorchLightning/pytorch…
tchaton Aug 27, 2021
d2c5bc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
eb791d6
update
tchaton Aug 27, 2021
622881b
update
tchaton Aug 27, 2021
cdf6438
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
0d47ec2
Push back changes :)
carmocca Aug 27, 2021
c5ef444
Fix test
carmocca Aug 27, 2021
3696300
Cache callback metrics
carmocca Aug 27, 2021
4afa246
Fix test
carmocca Aug 27, 2021
d4177cb
Merge branch 'master' into logging
carmocca Aug 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))


- Fixed reduction using `self.log(sync_dist=True, reduce_fx={mean, max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))


## [1.4.3] - 2021-08-17

- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
Expand Down
7 changes: 6 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ def log(
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
)

if reduce_fx in ("max", "min"):
sync_dist_fn = self.trainer.training_type_plugin.all_gather
else:
sync_dist_fn = self.trainer.training_type_plugin.reduce or sync_ddp

results.log(
self._current_fx_name,
name,
Expand All @@ -478,7 +483,7 @@ def log(
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist_fn=sync_dist_fn,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
Expand Down
26 changes: 18 additions & 8 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections.abc import Generator
from dataclasses import asdict, dataclass, replace
from functools import partial, wraps
Expand Down Expand Up @@ -51,16 +52,19 @@ class _Sync:
group: Optional[Any] = None

def __post_init__(self) -> None:
if self.fn is None:
if self.fn:
kwargs = dict(group=self.group)
if "reduce_op" in inspect.signature(self.fn).parameters:
kwargs["reduce_op"] = self.op

self.fn_call = partial(self.fn, **kwargs) if self.fn and self.should and not self.rank_zero_only else self.no_op

if not self.fn:
self.fn = self.no_op

@property
def __call__(self) -> Any:
return (
partial(self.fn, reduce_op=self.op, group=self.group)
if self.should and not self.rank_zero_only
else self.no_op
)
return self.fn_call
tchaton marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def no_op(value: Any, *_, **__) -> Any:
Expand Down Expand Up @@ -124,7 +128,7 @@ def forked_name(self, on_step: bool) -> str:

@property
def is_mean_reduction(self) -> bool:
return self.reduce_fx is torch.mean
return self.reduce_fx is (torch.mean)

@property
def is_sum_reduction(self) -> bool:
Expand Down Expand Up @@ -181,7 +185,10 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
self._forward_cache = value
# performance: no need to accumulate on values only logged on_step
if self.meta.on_step and not self.meta.on_epoch:
self.value = self.meta.sync(value)
value = self.meta.sync(value)
if self.meta.is_max_reduction or self.meta.is_min_reduction:
value = self.meta.reduce_fx(value)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self._forward_cache = self.value = value
return
# perform accumulation with reduction
if self.meta.is_mean_reduction:
Expand All @@ -202,6 +209,8 @@ def compute(self) -> torch.Tensor:
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
return value / cumulated_batch_size
elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction:
if value.dim() > 0:
value = self.meta.reduce_fx(value)
return value
return self.value.compute()

Expand Down Expand Up @@ -462,6 +471,7 @@ def log(

# check the stored metadata and the current one match
elif meta != self[key].meta:
print(meta, self[key].meta)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
raise MisconfigurationException(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)
Expand Down
80 changes: 61 additions & 19 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,37 +359,79 @@ def get_expected(on_epoch, values):
assert is_included if should_include else not is_included


@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))])
class LoggingSyncDistModel(BoringModel):
    """Test model that exercises ``self.log`` with ``sync_dist=True`` across every
    ``reduce_fx`` mode ({sum, mean, max}) for both on_step and on_epoch logging.

    The logged values deliberately depend on ``self.trainer.global_rank`` so that,
    when run on multiple devices, the cross-device reduction produces a different
    result per ``reduce_fx`` — letting the calling test assert each reduction path.
    NOTE(review): ``BoringModel`` is a project test fixture — assumed to provide
    default ``training_step``/``validation_step`` implementations; confirm.
    """

    def __init__(self, fake_result):
        super().__init__()
        # Base scalar that each rank offsets by its global_rank before logging.
        self.fake_result = fake_result

    def training_step(self, batch, batch_idx):
        # Rank-dependent value: differs per device, so sync_dist reductions are observable.
        value = self.fake_result + self.trainer.global_rank
        # on_step-only metrics: reduced across ranks immediately at each step.
        self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
        self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
        self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
        self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
        self.log(
            "foo_5", batch_idx + self.trainer.global_rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max"
        )

        # on_epoch-only metrics: accumulated over the epoch, then reduced across ranks.
        self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
        self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
        self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
        self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
        # foo_10 intentionally does NOT add global_rank: max over batch_idx is rank-invariant.
        self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        # Same three reduction modes exercised on the validation loop (on_epoch only).
        self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
        self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
        self.log(
            "bar_3", batch_idx + self.trainer.global_rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max"
        )
        return super().validation_step(batch, batch_idx)


@pytest.mark.parametrize(
"gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))]
)
def test_logging_sync_dist_true(tmpdir, gpus):
"""
Tests to ensure that the sync_dist flag works (should just return the original value)
"""
fake_result = 1

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
return super().validation_step(batch, batch_idx)

model = TestModel()
model = LoggingSyncDistModel(fake_result)
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
limit_train_batches=3,
limit_val_batches=3,
weights_summary=None,
gpus=gpus,
)
trainer.fit(model)

assert trainer.logged_metrics["foo"] == fake_result
assert trainer.logged_metrics["foo_2"] == 2
assert trainer.logged_metrics["bar"] == fake_result
num_devices = 1 if gpus is None else gpus
use_multiple_devices = num_devices > 1
total = fake_result * num_devices + 1

assert trainer.callback_metrics["foo"] == total if use_multiple_devices else fake_result
assert trainer.callback_metrics["foo_2"] == 2 * num_devices
assert trainer.callback_metrics["foo_3"] == 2
assert trainer.callback_metrics["foo_4"] == total / num_devices if use_multiple_devices else 1
assert trainer.callback_metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2

assert trainer.callback_metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2
assert trainer.callback_metrics["foo_7"] == 2 * num_devices * 3
assert trainer.callback_metrics["foo_8"] == 2
assert (
trainer.callback_metrics["foo_9"] == (fake_result * 2 + 1) / num_devices
if use_multiple_devices
else fake_result
)
assert trainer.callback_metrics["foo_10"] == 2

assert trainer.callback_metrics["bar"] == fake_result * 3 * num_devices
assert trainer.callback_metrics["bar_2"] == fake_result
assert trainer.callback_metrics["bar_3"] == 2 + int(use_multiple_devices)


@RunIf(min_gpus=2, special=True)
Expand Down