Fix default logging levels for train step specific hooks (#10756)
rohitgr7 authored Nov 29, 2021
1 parent 088818f commit 753cc4d
Showing 6 changed files with 260 additions and 69 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -196,6 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))


- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))



-


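For users, the fix means a bare `self.log()` call inside a train-batch hook now records per step by default instead of per epoch. A minimal sketch of the new behavior (the module and metric names are illustrative, and the hook signature follows the 1.5-series `LightningModule` API):

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # With no explicit flags, this call used to resolve to
        # `on_step=False, on_epoch=True`; after this commit it resolves to
        # `on_step=True, on_epoch=False`.
        self.log("train_batch_metric", 1.0)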
21 changes: 4 additions & 17 deletions pytorch_lightning/core/lightning.py
@@ -353,10 +353,6 @@ def log(
value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
)

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self.trainer is None:
# not an error to support testing the `*_step` methods without a `Trainer` reference
rank_zero_warn(
@@ -375,7 +371,10 @@
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)
_FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
self._current_fx_name, on_step=on_step, on_epoch=on_epoch
)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
@@ -530,18 +529,6 @@ def log_grad_norm(self, grad_norm_dict):
"""
self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)

def __auto_choose_log_on_step(self, on_step: Optional[bool]) -> bool:
if on_step is None:
on_step = False
on_step |= self._current_fx_name in ("training_step", "training_step_end")
return on_step

def __auto_choose_log_on_epoch(self, on_epoch: Optional[bool]) -> bool:
if on_epoch is None:
on_epoch = True
on_epoch &= self._current_fx_name not in ("training_step", "training_step_end")
return on_epoch

def all_gather(
self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
):
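The two removed `__auto_choose_log_*` helpers are folded into the single validator call above. A quick sketch of what that call resolves when both flags are left as `None`, using the defaults from the `_FxValidator.functions` table shown in the next file:

from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator

# Train-batch hooks now default to per-step logging...
assert _FxValidator.check_logging_and_get_default_levels(
    "on_train_batch_end", on_step=None, on_epoch=None
) == (True, False)

# ...while validation and test hooks keep the per-epoch default.
assert _FxValidator.check_logging_and_get_default_levels(
    "validation_step", on_step=None, on_epoch=None
) == (False, True)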
184 changes: 138 additions & 46 deletions pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union

from typing_extensions import TypedDict

@@ -20,50 +20,98 @@

class _FxValidator:
class _LogOptions(TypedDict):
on_step: Union[Tuple[bool], Tuple[bool, bool]]
on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]]
allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
default_on_step: bool
default_on_epoch: bool

functions = {
"on_before_accelerator_backend_setup": None,
"on_configure_sharded_model": None,
"on_before_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_after_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_before_optimizer_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_before_zero_grad": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_before_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_after_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_init_start": None,
"on_init_end": None,
"on_fit_start": None,
"on_fit_end": None,
"on_sanity_check_start": None,
"on_sanity_check_end": None,
"on_train_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_train_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_train_end": None,
"on_validation_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_validation_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_validation_end": None,
"on_test_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_test_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_test_end": None,
"on_predict_start": None,
"on_predict_end": None,
"on_pretrain_routine_start": None,
"on_pretrain_routine_end": None,
"on_train_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_train_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_validation_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_test_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_train_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_train_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_validation_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_validation_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_test_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_test_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_predict_epoch_start": None,
"on_predict_epoch_end": None,
"on_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_train_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_train_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_validation_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_validation_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_test_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_test_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_train_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_train_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_validation_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_validation_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_test_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_test_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_predict_batch_start": None,
"on_predict_batch_end": None,
"on_keyboard_interrupt": None,
@@ -73,16 +121,34 @@ class _LogOptions(TypedDict):
"setup": None,
"teardown": None,
"configure_sharded_model": None,
"training_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"validation_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"test_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"training_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"validation_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"test_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"predict_step": None,
"training_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"validation_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"test_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"training_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"training_step_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"validation_step_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"test_step_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"training_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"validation_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"test_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"configure_optimizers": None,
"on_train_dataloader": None,
"train_dataloader": None,
@@ -97,22 +163,48 @@ class _LogOptions(TypedDict):
}

@classmethod
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the given function name is allowed to log."""
def check_logging(cls, fx_name: str) -> None:
"""Check if the given hook is allowed to log."""
if fx_name not in cls.functions:
raise RuntimeError(
f"Logging inside `{fx_name}` is not implemented."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
)
allowed = cls.functions[fx_name]
if allowed is None:
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")

m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
if on_step not in allowed["on_step"]:
msg = m.format("on_step", on_step, fx_name, allowed["on_step"])
if cls.functions[fx_name] is None:
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`.")

@classmethod
def get_default_logging_levels(
cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
) -> Tuple[bool, bool]:
"""Return default logging levels for given hook."""
fx_config = cls.functions[fx_name]
assert fx_config is not None
on_step = fx_config["default_on_step"] if on_step is None else on_step
on_epoch = fx_config["default_on_epoch"] if on_epoch is None else on_epoch
return on_step, on_epoch

@classmethod
def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the logging levels are allowed in the given hook."""
fx_config = cls.functions[fx_name]
assert fx_config is not None
m = "You can't `self.log({}={})` inside `{}`, must be one of {}."
if on_step not in fx_config["allowed_on_step"]:
msg = m.format("on_step", on_step, fx_name, fx_config["allowed_on_step"])
raise MisconfigurationException(msg)

if on_epoch not in allowed["on_epoch"]:
msg = m.format("on_epoch", on_epoch, fx_name, allowed["on_epoch"])
if on_epoch not in fx_config["allowed_on_epoch"]:
msg = m.format("on_epoch", on_epoch, fx_name, fx_config["allowed_on_epoch"])
raise MisconfigurationException(msg)

@classmethod
def check_logging_and_get_default_levels(
cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
) -> Tuple[bool, bool]:
"""Check if the given hook name is allowed to log and return logging levels."""
cls.check_logging(fx_name)
on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch)
cls.check_logging_levels(fx_name, on_step, on_epoch)
return on_step, on_epoch
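Explicit arguments are still validated against the `allowed_on_step`/`allowed_on_epoch` tuples, so a hook that only permits epoch-level logging rejects `on_step=True`. A small sketch of that failure path, using a hook and message taken from the table above:

from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # `on_train_start` only allows `on_step=False`.
    _FxValidator.check_logging_and_get_default_levels(
        "on_train_start", on_step=True, on_epoch=None
    )
except MisconfigurationException as err:
    print(err)
    # You can't `self.log(on_step=True)` inside `on_train_start`, must be one of (False,).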
8 changes: 4 additions & 4 deletions tests/trainer/logging_/test_logger_connector.py
@@ -141,17 +141,17 @@ def test_fx_validator(tmpdir):
and func_name not in ["on_train_end", "on_test_end", "on_validation_end"]
)
if allowed:
validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
validator.check_logging_levels(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
if not is_start and is_stage:
with pytest.raises(MisconfigurationException, match="must be one of"):
validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch)
validator.check_logging_levels(fx_name=func_name, on_step=True, on_epoch=on_epoch)
else:
assert func_name in not_supported
with pytest.raises(MisconfigurationException, match="You can't"):
validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
validator.check_logging(fx_name=func_name)

with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"):
validator.check_logging("foo", False, False)
validator.check_logging("foo")


class HookedCallback(Callback):
