
Commit dc706e0

Fix support for logging within callbacks returned from `LightningModule` (#10991)

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
3 people committed Dec 14, 2021
1 parent 6e6ff06 commit dc706e0
Showing 5 changed files with 30 additions and 14 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -9,7 +9,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed a bug where the DeepSpeedPlugin arguments `cpu_checkpointing` and `contiguous_memory_optimization` were not being forwarded to deepspeed correctly ([#10874](https://github.com/PyTorchLightning/pytorch-lightning/issues/10874))


- Fixed an issue with `NeptuneLogger` causing checkpoints to be uploaded with a duplicated file extension ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/issues/11015))


- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))


-


-


## [1.5.5] - 2021-12-07
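For context, the user-facing pattern this commit fixes is a `LightningModule` that returns callbacks from `configure_callbacks()` and expects those callbacks to log. A minimal sketch of that pattern, assuming PyTorch Lightning ~1.5 (the `EpochReporter` and `MyModule` names are illustrative, not part of this diff):

from pytorch_lightning import Callback, LightningModule


class EpochReporter(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Works because the Trainer attaches the module's `log`/`log_dict`
        # to every callback, including ones returned by the module itself.
        self.log("current_epoch", float(trainer.current_epoch))


class MyModule(LightningModule):
    def configure_callbacks(self):
        # Before this fix, callbacks returned here were collected only after
        # the logging functions had been attached, so calling `self.log`
        # inside them did not work.
        return [EpochReporter()]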
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
@@ -19,7 +19,7 @@
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn


def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def verify_loop_configurations(trainer: "pl.Trainer") -> None:
r"""
Checks that the model is configured correctly before the run is started.
@@ -28,6 +28,10 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
model: The model to check the configuration.
"""
model = trainer.lightning_module

if trainer.state.fn is None:
raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
__verify_train_val_loop_configuration(trainer, model)
__verify_manual_optimization_support(trainer, model)
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -249,10 +249,11 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
def _trainer_has_checkpoint_callbacks(self):
return len(self.trainer.checkpoint_callbacks) > 0

def attach_model_logging_functions(self, model):
def _attach_model_logging_functions(self):
lightning_module = self.trainer.lightning_module
for callback in self.trainer.callbacks:
callback.log = model.log
callback.log_dict = model.log_dict
callback.log = lightning_module.log
callback.log_dict = lightning_module.log_dict

def _attach_model_callbacks(self) -> None:
"""Attaches the callbacks defined in the model.
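Conceptually, `_attach_model_logging_functions` now pulls the module from the trainer instead of taking it as an argument, and rebinds the module's `log`/`log_dict` onto each callback. A rough standalone sketch of that rebinding idea in plain Python (the `_Demo*` names are illustrative, not from the codebase):

class _DemoModule:
    def log(self, name, value, **kwargs):
        print(f"log({name}={value})")

    def log_dict(self, metrics, **kwargs):
        for name, value in metrics.items():
            self.log(name, value, **kwargs)


class _DemoCallback:
    pass


module, callback = _DemoModule(), _DemoCallback()
# Same idea as the connector method above: hand the callback the module's
# bound methods so `callback.log(...)` forwards to the module's logging.
callback.log = module.log
callback.log_dict = module.log_dict
callback.log("accuracy", 0.9)  # prints: log(accuracy=0.9)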
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -1114,17 +1114,16 @@ def _run(
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)

verify_loop_configurations(self, model)

# attach model log function to callback
self._callback_connector.attach_model_logging_functions(model)

# attach model to the training type plugin
self.training_type_plugin.connect(model)

self._callback_connector._attach_model_callbacks()
self._callback_connector._attach_model_logging_functions()

verify_loop_configurations(self)

# hook
self._data_connector.prepare_data()
self._callback_connector._attach_model_callbacks()

# ----------------------------
# SET UP TRAINING
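In short, the reorder in `Trainer._run` attaches the callbacks returned from `configure_callbacks()` before the logging functions are bound and before the loop configuration is validated, which is what makes logging from those callbacks work. It also lets `verify_loop_configurations` drop its `model` argument, since the model is reachable via `trainer.lightning_module` at that point.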
8 changes: 4 additions & 4 deletions tests/models/test_hooks.py
@@ -499,8 +499,8 @@ def training_step(self, batch, batch_idx):
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="prepare_data"),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
# DeepSpeed needs the batch size to figure out throughput logging
*([dict(name="train_dataloader")] if kwargs.get("strategy") == "deepspeed" else []),
@@ -618,8 +618,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="prepare_data"),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
dict(name="setup", kwargs=dict(stage="fit")),
@@ -716,8 +716,8 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="prepare_data"),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage=verb)),
dict(name="setup", kwargs=dict(stage=verb)),
@@ -748,8 +748,8 @@ def test_trainer_model_hook_system_predict(tmpdir):
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="prepare_data"),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="predict")),
dict(name="setup", kwargs=dict(stage="predict")),
