Fix self.log(sync_dist=True, reduce_fx={mean,max}) (#9142)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
4 people authored and lexierule committed Sep 1, 2021
1 parent 2edd154 commit ceb8bdf
Showing 6 changed files with 104 additions and 51 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dist=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))


- Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))

## [1.4.4] - 2021-08-24
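For context, the fixed code path is exercised by calls of roughly this shape; the module and metric names below are illustrative, not taken from the commit:

```python
# Hypothetical usage sketch (not part of the diff below): the kind of call
# whose reduction this commit fixes when running on multiple processes.
import torch
from pytorch_lightning import LightningModule


class SyncDistExample(LightningModule):  # illustrative model, only logging shown
    def validation_step(self, batch, batch_idx):
        score = torch.rand(())  # placeholder metric value
        # Before 1.4.5, sync_dist=True combined with reduce_fx="mean" or "max"
        # could reduce the value incorrectly across processes.
        self.log("score_mean", score, sync_dist=True, reduce_fx="mean")
        self.log("score_max", score, sync_dist=True, reduce_fx="max")
```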
59 changes: 34 additions & 25 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -45,22 +45,33 @@ class MetricSource(LightningEnum):
@dataclass
class _Sync:
fn: Optional[Callable] = None
should: bool = False
_should: bool = False
rank_zero_only: bool = False
op: Optional[str] = None
group: Optional[Any] = None

def __post_init__(self) -> None:
if self.fn is None:
self.fn = self.no_op
self._generate_sync_fn()

@property
def should(self) -> bool:
return self._should

@should.setter
def should(self, should: bool) -> None:
self._should = should
# `self._fn` needs to be re-generated.
self._generate_sync_fn()

def _generate_sync_fn(self) -> None:
"""Used to compute the syncing function and cache it."""
fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
# save the function as `_fn` as the meta is being re-created and the object references need to match.
self._fn = partial(fn, reduce_op=self.op, group=self.group)

@property
def __call__(self) -> Any:
return (
partial(self.fn, reduce_op=self.op, group=self.group)
if self.should and not self.rank_zero_only
else self.no_op
)
return self._fn

@staticmethod
def no_op(value: Any, *_, **__) -> Any:
@@ -75,31 +86,28 @@ class _Metadata:
logger: bool = True
on_step: bool = False
on_epoch: bool = True
_reduce_fx: Callable = torch.mean
reduce_fx: Callable = torch.mean
enable_graph: bool = False
dataloader_idx: Optional[int] = None
metric_attribute: Optional[str] = None
_sync: Optional[_Sync] = None

@property
def reduce_fx(self) -> Callable:
return self._reduce_fx
def __post_init__(self) -> None:
self._parse_reduce_fx()

@reduce_fx.setter
def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
def _parse_reduce_fx(self) -> None:
error = (
"Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
f" Found: {reduce_fx}"
f" Found: {self.reduce_fx}"
)
self._reduce_fx = reduce_fx
if isinstance(reduce_fx, str):
reduce_fx = reduce_fx.lower()
if isinstance(self.reduce_fx, str):
reduce_fx = self.reduce_fx.lower()
if reduce_fx == "avg":
reduce_fx = "mean"
if reduce_fx not in ("min", "max", "mean", "sum"):
raise MisconfigurationException(error)
self._reduce_fx = getattr(torch, reduce_fx)
self.reduce_fx = getattr(torch, reduce_fx)
elif self.is_custom_reduction:
raise MisconfigurationException(error)
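
The `_parse_reduce_fx` hunk above normalizes string arguments before resolving them to a torch reduction. A minimal standalone sketch of that normalization (the helper name is hypothetical and `ValueError` stands in for `MisconfigurationException`):

```python
import torch


def resolve_reduce_fx(reduce_fx):  # illustrative helper, not Lightning API
    """Map 'avg'/'mean'/'min'/'max'/'sum' strings to the torch reduction."""
    if isinstance(reduce_fx, str):
        name = reduce_fx.lower()
        if name == "avg":
            name = "mean"
        if name not in ("min", "max", "mean", "sum"):
            raise ValueError(f"Unsupported reduce_fx: {reduce_fx!r}")
        return getattr(torch, name)
    return reduce_fx


assert resolve_reduce_fx("AVG") is torch.mean
assert resolve_reduce_fx("max") is torch.max
```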

@@ -178,11 +186,11 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
if self.is_tensor:
value = value.float()
self._forward_cache = value
# performance: no need to accumulate on values only logged on_step
if self.meta.on_step and not self.meta.on_epoch:
self.value = self.meta.sync(value)
self._forward_cache = self.value = self.meta.sync(value)
return
self._forward_cache = value
# perform accumulation with reduction
if self.meta.is_mean_reduction:
self.value += value.mean() * batch_size
@@ -201,8 +209,7 @@ def compute(self) -> torch.Tensor:
if self.meta.is_mean_reduction:
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
return value / cumulated_batch_size
elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction:
return value
return value
return self.value.compute()

def reset(self) -> None:
@@ -448,12 +455,12 @@ def log(
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
)
meta.reduce_fx = reduce_fx
meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)

# register logged value if it doesn't exist
if key not in self:
@@ -669,6 +676,8 @@ def load_state_dict(

if not metrics:
return

# iterate through result metrics and re-attach Metric references on reload.
result_metrics = self.result_metrics
for metric_attribute, metric in metrics.items():
for result_metric in result_metrics:
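The main structural change in result.py is that `_Sync` now caches its reduction callable in `_fn` and rebuilds it whenever `should` is reassigned, so the object reference handed out earlier keeps pointing at the right behavior. A minimal sketch of that caching pattern, assuming a simplified class without `group` or `rank_zero_only`:

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional


@dataclass
class CachedSync:
    """Illustrative stand-in for `_Sync`, not the actual Lightning class."""

    fn: Optional[Callable] = None
    _should: bool = False
    op: Optional[str] = None

    def __post_init__(self) -> None:
        # Build the cached callable up front so later toggles of `should`
        # only change what `_fn` does, not the attribute other code reads.
        self._rebuild()

    @property
    def should(self) -> bool:
        return self._should

    @should.setter
    def should(self, value: bool) -> None:
        self._should = value
        self._rebuild()  # keep the cached callable consistent with the flag

    def _rebuild(self) -> None:
        def no_op(value, **_):
            return value

        fn = self.fn if (self.fn is not None and self._should) else no_op
        self._fn = partial(fn, reduce_op=self.op)


# The same object switches between the no-op and the real reduction.
sync = CachedSync(fn=lambda value, reduce_op=None: value * 2, op="sum")
assert sync._fn(3) == 3   # should=False, so the no-op path is used
sync.should = True
assert sync._fn(3) == 6   # rebuilt to call `fn` with reduce_op="sum"
```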
12 changes: 8 additions & 4 deletions pytorch_lightning/utilities/distributed.py
@@ -179,10 +179,14 @@ def sync_ddp(
if group is None:
group = torch.distributed.group.WORLD

op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM

if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
divide_by_world_size = True
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
op = ReduceOp.SUM
divide_by_world_size = True
else:
op = getattr(ReduceOp, reduce_op.upper())
else:
op = reduce_op

# sync all processes before reduction
torch.distributed.barrier(group=group)
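The sync_ddp change above resolves string reduce ops explicitly and treats "avg"/"mean" as a SUM followed by division by the world size, since there is no portable mean op. A standalone sketch of that branch, assuming an already-initialized `torch.distributed` process group (the helper name is illustrative, not Lightning's API):

```python
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def reduce_across_processes(result: torch.Tensor, reduce_op="sum", group=None) -> torch.Tensor:
    # Mirror of the branch above: strings are mapped to ReduceOp members,
    # with "avg"/"mean" implemented as SUM + division by the world size.
    divide_by_world_size = False
    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())  # e.g. "max" -> ReduceOp.MAX
    else:
        op = reduce_op

    dist.all_reduce(result, op=op, group=group, async_op=False)
    if divide_by_world_size:
        result = result / dist.get_world_size(group)
    return result
```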
4 changes: 2 additions & 2 deletions tests/core/test_metric_result_integration.py
@@ -23,7 +23,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -331,7 +331,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
# default sync fn
new_results = ResultCollection(False, device)
new_results.load_state_dict(state_dict, map_location="cpu")
assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op
assert new_results["validation_step.v"].meta.sync.fn is None

# check map location
assert new_results["validation_step.v"].value.device.type == "cpu"
2 changes: 1 addition & 1 deletion tests/core/test_results.py
@@ -39,7 +39,7 @@ def _setup_ddp(rank, worldsize):
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.0])
sync = _Sync(sync_ddp_if_available, should=True, op="SUM")
sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
actual = sync(tensor)
assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"

75 changes: 56 additions & 19 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -360,37 +360,74 @@ def get_expected(on_epoch, values):
assert is_included if should_include else not is_included


@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))])
class LoggingSyncDistModel(BoringModel):
def __init__(self, fake_result):
super().__init__()
self.fake_result = fake_result

@property
def rank(self) -> int:
return self.trainer.global_rank

def training_step(self, batch, batch_idx):
value = self.fake_result + self.rank
self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
self.log("foo_5", batch_idx + self.rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max")

self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
self.log("bar_3", batch_idx + self.rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
return super().validation_step(batch, batch_idx)


@pytest.mark.parametrize(
"gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))]
)
def test_logging_sync_dist_true(tmpdir, gpus):
"""
Tests to ensure that the sync_dist flag works (should just return the original value)
"""
fake_result = 1

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
return super().validation_step(batch, batch_idx)

model = TestModel()
model = LoggingSyncDistModel(fake_result)
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
limit_train_batches=3,
limit_val_batches=3,
weights_summary=None,
gpus=gpus,
)
trainer.fit(model)

assert trainer.logged_metrics["foo"] == fake_result
assert trainer.logged_metrics["foo_2"] == 2
assert trainer.logged_metrics["bar"] == fake_result
num_devices = 1 if gpus is None else gpus
use_multiple_devices = num_devices > 1
total = fake_result * num_devices + 1

metrics = trainer.callback_metrics
assert metrics["foo"] == total if use_multiple_devices else fake_result
assert metrics["foo_2"] == 2 * num_devices
assert metrics["foo_3"] == 2
assert metrics["foo_4"] == total / num_devices if use_multiple_devices else 1
assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2
assert metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2
assert metrics["foo_7"] == 2 * num_devices * 3
assert metrics["foo_8"] == 2
assert metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result
assert metrics["foo_10"] == 2
assert metrics["bar"] == fake_result * 3 * num_devices
assert metrics["bar_2"] == fake_result
assert metrics["bar_3"] == 2 + int(use_multiple_devices)


@RunIf(min_gpus=2, special=True)
