Fix support for logging within callbacks returned from LightningModule #10991

Merged
11 commits merged on Dec 14, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -264,6 +264,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where the DeepSpeedPlugin arguments `cpu_checkpointing` and `contiguous_memory_optimization` were not being forwarded to deepspeed correctly ([#10874](https://github.com/PyTorchLightning/pytorch-lightning/issues/10874))


- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))


-


-


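For context, here is a minimal, self-contained sketch (not part of this PR) of the use case the changelog entry above describes: a callback returned from `LightningModule.configure_callbacks()` that logs through the `log`/`log_dict` functions the trainer binds onto it. The model, callback, and metric names are illustrative, and the snippet assumes the 1.5-era PyTorch Lightning API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class EpochMarkerCallback(pl.Callback):
    # Callback itself defines no `log`; the trainer copies the LightningModule's
    # `log`/`log_dict` onto every callback instance before training starts.
    def on_train_epoch_end(self, trainer, pl_module):
        self.log("epoch_marker", float(trainer.current_epoch))


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def configure_callbacks(self):
        # Before this fix, `log`/`log_dict` were bound to callbacks before
        # `configure_callbacks` ran, so callbacks returned here never got them.
        return [EpochMarkerCallback()]


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = pl.Trainer(max_epochs=1, logger=False)
    trainer.fit(ToyModel(), train_data)
```

With the fix, the callback's `self.log(...)` resolves to the module's `log`; without it, the attribute was never set on callbacks coming from `configure_callbacks`.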
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -255,10 +255,11 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
     def _trainer_has_checkpoint_callbacks(self):
         return len(self.trainer.checkpoint_callbacks) > 0

-    def attach_model_logging_functions(self, model):
+    def _attach_model_logging_functions(self):
+        lightning_module = self.trainer.lightning_module
         for callback in self.trainer.callbacks:
-            callback.log = model.log
-            callback.log_dict = model.log_dict
+            callback.log = lightning_module.log
+            callback.log_dict = lightning_module.log_dict

     def _attach_model_callbacks(self) -> None:
         """Attaches the callbacks defined in the model.
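As a side note on the mechanism (a toy sketch, not the library code): the connector simply assigns the LightningModule's bound `log`/`log_dict` methods onto each callback instance, so a callback calling `self.log(...)` actually invokes the module's `log` with the module as `self`. The class names below are made up for illustration.

```python
class Module:
    def log(self, name, value):
        # `self` here is always the Module instance, even when the call
        # comes through an attribute stored on another object.
        print(f"{type(self).__name__} logged {name}={value}")


class Callback:
    pass


module = Module()
callback = Callback()

callback.log = module.log        # copy the *bound* method onto the callback
callback.log("accuracy", 0.95)   # prints "Module logged accuracy=0.95"
```

This also explains the ordering change in trainer.py below: the binding has to happen after `_attach_model_callbacks()` has run `configure_callbacks()` and extended `trainer.callbacks`, otherwise callbacks coming from the model never receive the attribute.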
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -1091,17 +1091,16 @@ def _run(
         if hasattr(model, "hparams"):
             parsing.clean_namespace(model.hparams)

-        verify_loop_configurations(self, model)
-
-        # attach model log function to callback
-        self._callback_connector.attach_model_logging_functions(model)
-
         # attach model to the training type plugin
         self.training_type_plugin.connect(model)

+        self._callback_connector._attach_model_callbacks()
+        self._callback_connector._attach_model_logging_functions()
+
+        verify_loop_configurations(self, model)
+
         # hook
         self._data_connector.prepare_data()
-        self._callback_connector._attach_model_callbacks()

         # ----------------------------
         # SET UP TRAINING
8 changes: 4 additions & 4 deletions tests/models/test_hooks.py
@@ -506,8 +506,8 @@ def training_step(self, batch, batch_idx):
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         # DeepSpeed needs the batch size to figure out throughput logging
         *([dict(name="train_dataloader")] if kwargs.get("strategy") == "deepspeed" else []),
@@ -630,8 +630,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
@@ -715,8 +715,8 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage=verb)),
         dict(name="setup", kwargs=dict(stage=verb)),
@@ -747,8 +747,8 @@ def test_trainer_model_hook_system_predict(tmpdir):
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
-        dict(name="prepare_data"),
         dict(name="configure_callbacks"),
+        dict(name="prepare_data"),
         dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="predict")),
         dict(name="setup", kwargs=dict(stage="predict")),