Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always use trainer.call_hook #8498

Merged
merged 36 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ee140af
Always use `trainer.call_hook`
carmocca Jul 20, 2021
7106af0
Fix breakage
carmocca Jul 20, 2021
a370bac
Update CHANGELOG
carmocca Jul 20, 2021
46116ee
Check the exact matches
carmocca Jul 20, 2021
c5e2667
Better args
carmocca Jul 20, 2021
07503a1
Docs and FIXME
carmocca Jul 20, 2021
fe1585d
Remove extra assertion
carmocca Jul 21, 2021
a02cfac
Resolve model connection
carmocca Jul 21, 2021
f855055
Fixes
carmocca Jul 21, 2021
b0788a1
Remove no lightning module check
carmocca Jul 21, 2021
12658fe
Check callable
carmocca Jul 21, 2021
f432fca
Fix mock test
carmocca Jul 21, 2021
b066f6f
Add pl_module to call_hook
carmocca Jul 21, 2021
7149a65
Do not check is overridden
carmocca Jul 21, 2021
a735468
yapf
carmocca Jul 21, 2021
1b21b7d
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Jul 23, 2021
15bbcfc
Fix test
carmocca Jul 23, 2021
7a5c093
Fix test
carmocca Jul 23, 2021
be55e5b
Fix test
carmocca Jul 23, 2021
1b88471
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2021
eebe803
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Aug 5, 2021
f5b3ea8
Minor changes
carmocca Aug 5, 2021
8390f20
Remove reference
carmocca Aug 6, 2021
2edeb4f
Keep is_overridden check
carmocca Aug 6, 2021
b3bed31
Ooops
carmocca Aug 6, 2021
1b58d54
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Aug 6, 2021
9a35831
Remove stack inspection
carmocca Aug 6, 2021
6cb74ff
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Aug 6, 2021
47ca3bf
Revert test changes. Remove mock usage
carmocca Aug 6, 2021
8172cc9
Check no trainer
carmocca Aug 6, 2021
dc7e12e
Wording
carmocca Aug 6, 2021
8123c43
Period
carmocca Aug 6, 2021
395221f
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Aug 9, 2021
87e81e2
Revert "Keep is_overridden check"
carmocca Aug 9, 2021
7b1ae02
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Aug 10, 2021
e97c094
Merge branch 'master' into bugfix/call-hook-standardization
carmocca Aug 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


- Improve coverage of `self.log`-ing in any `LightningModule` or `Callback` hook ([#8498](https://github.com/PyTorchLightning/pytorch-lightning/pull/8498))


- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608]
(https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))

Expand Down
17 changes: 15 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,22 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self.trainer is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
" This is most likely because the model hasn't been passed to the `Trainer`"
)
results = self.trainer._results
assert results is not None
assert self._current_fx_name is not None
if results is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
" yet. This is most likely because you are trying to log in a `predict` hook,"
" but it doesn't support logging"
)
if self._current_fx_name is None:
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)
FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,10 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:

def on_evaluation_model_eval(self) -> None:
"""Sets model to eval mode"""
model_ref = self.trainer.lightning_module
if self.trainer.testing:
model_ref.on_test_model_eval()
self.trainer.call_hook("on_test_model_eval")
else:
model_ref.on_validation_model_eval()
self.trainer.call_hook("on_validation_model_eval")

def on_evaluation_model_train(self) -> None:
"""Sets model to train mode"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _attach_model_callbacks(self) -> None:
In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
will be pushed to the end of the list, ensuring they run last.
"""
model_callbacks = self.trainer.lightning_module.configure_callbacks()
model_callbacks = self.trainer.call_hook("configure_callbacks")
if not model_callbacks:
return
model_callback_types = {type(c) for c in model_callbacks}
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def prepare_data(self) -> None:
if self.can_prepare_data():
if self.trainer.datamodule is not None:
self.trainer.datamodule.prepare_data()
self.trainer.lightning_module.prepare_data()
self.trainer.call_hook("prepare_data")
self.trainer._is_data_prepared = True

def can_prepare_data(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,30 @@ class FxValidator:
training_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
validation_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
test_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
on_before_batch_transfer=None,
transfer_batch_to_device=None,
on_after_batch_transfer=None,
backward=None,
optimizer_step=None,
# TODO(@carmocca): some {step,epoch}_{start,end} are missing
configure_optimizers=None,
on_train_dataloader=None,
train_dataloader=None,
on_val_dataloader=None,
val_dataloader=None,
on_test_dataloader=None,
test_dataloader=None,
prepare_data=None,
configure_callbacks=None,
on_validation_model_eval=None,
on_test_model_eval=None,
)

@classmethod
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the given function name is allowed to log"""
if fx_name not in cls.functions:
raise RuntimeError(
f"You are trying to `self.log()` inside `{fx_name}` but it is not implemented."
f"Logging inside `{fx_name}` is not implemented."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
)
allowed = cls.functions[fx_name]
if allowed is None:
raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`")
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")

m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
if on_step not in allowed["on_step"]:
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,9 @@ def request_dataloader(
Returns:
The dataloader
"""
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
hook = f"{stage}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ class TrainerOptimizersMixin(ABC):

_lightning_optimizers: Optional[List[LightningOptimizer]]

def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]:
def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
pl_module = self.lightning_module or model
self._lightning_optimizers = None
optim_conf = model.configure_optimizers()
optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
Expand Down Expand Up @@ -95,7 +96,7 @@ def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)

is_manual_optimization = not model.automatic_optimization
is_manual_optimization = not pl_module.automatic_optimization
lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)

Expand Down
58 changes: 25 additions & 33 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,20 +1026,14 @@ def _pre_training_routine(self):
# --------------------------
# Pre-train
# --------------------------
# on pretrain routine start
ref_model = self.lightning_module

self.on_pretrain_routine_start()
ref_model.on_pretrain_routine_start()
self.call_hook("on_pretrain_routine_start")

# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
max_depth = ModelSummary.MODES[self.weights_summary]
summarize(ref_model, max_depth=max_depth)
summarize(self.lightning_module, max_depth=max_depth)

# on pretrain routine end
self.on_pretrain_routine_end()
ref_model.on_pretrain_routine_end()
self.call_hook("on_pretrain_routine_end")

def _run_train(self) -> None:
self._pre_training_routine()
Expand Down Expand Up @@ -1123,8 +1117,7 @@ def _run_sanity_check(self, ref_model):
stage = self.state.stage
self.sanity_checking = True

# hook and callback
self.on_sanity_check_start()
self.call_hook("on_sanity_check_start")

# reload dataloaders
self._evaluation_loop.reload_evaluation_dataloaders()
Expand All @@ -1133,7 +1126,7 @@ def _run_sanity_check(self, ref_model):
with torch.no_grad():
self._evaluation_loop.run()

self.on_sanity_check_end()
self.call_hook("on_sanity_check_end")

# reset validation metrics
self.logger_connector.reset()
Expand Down Expand Up @@ -1189,8 +1182,7 @@ def _call_setup_hook(self) -> None:

if self.datamodule is not None:
self.datamodule.setup(stage=fn)
self.setup(stage=fn)
self.lightning_module.setup(stage=fn)
self.call_hook("setup", stage=fn)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

self.accelerator.barrier("post_setup")

Expand All @@ -1203,8 +1195,8 @@ def _call_configure_sharded_model(self) -> None:
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
with self.accelerator.model_sharded_context():
model.configure_sharded_model()
self.on_configure_sharded_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")
model.call_configure_sharded_model_hook = True
self.accelerator.call_configure_sharded_model_hook = False

Expand All @@ -1213,9 +1205,7 @@ def _call_teardown_hook(self) -> None:

if self.datamodule is not None:
self.datamodule.teardown(stage=fn)

self.teardown(stage=fn)
self.lightning_module.teardown(stage=fn)
self.call_hook("teardown", stage=fn)

self.lightning_module._current_fx_name = None
self.lightning_module._current_dataloader_idx = None
Expand All @@ -1230,38 +1220,40 @@ def _call_teardown_hook(self) -> None:
# summarize profile results
self.profiler.describe()

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
if self.lightning_module:
prev_fx_name = self.lightning_module._current_fx_name
self.lightning_module._current_fx_name = hook_name
def call_hook(
self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
) -> Any:
pl_module = self.lightning_module or pl_module
if pl_module:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name

# always profile hooks
with self.profiler.profile(hook_name):

# first call trainer hook
if hasattr(self, hook_name):
trainer_hook = getattr(self, hook_name)
trainer_hook(*args, **kwargs)
callback_fx = getattr(self, hook_name, None)
if callable(callback_fx):
callback_fx(*args, **kwargs)

# next call hook in lightningModule
output = None
model_ref = self.lightning_module
if is_overridden(hook_name, model_ref):
hook_fx = getattr(model_ref, hook_name)
output = hook_fx(*args, **kwargs)
model_fx = getattr(pl_module, hook_name, None)
if callable(model_fx):
output = model_fx(*args, **kwargs)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

# call the accelerator hook
if hasattr(self.accelerator, hook_name):
if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
accelerator_hook = getattr(self.accelerator, hook_name)
accelerator_output = accelerator_hook(*args, **kwargs)
# Rely on the accelerator output if lightningModule hook returns nothing
# Required for cases such as DataParallel where we reduce the output for the user
# todo: move this data parallel logic into the data parallel plugin
output = accelerator_output if output is None else output

if self.lightning_module:
if pl_module:
# restore current_fx when nested context
self.lightning_module._current_fx_name = prev_fx_name
pl_module._current_fx_name = prev_fx_name

return output

Expand Down
3 changes: 0 additions & 3 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):

model = BoringModel()
dm = BoringDataModule()
trainer = Trainer()
trainer.model = model
trainer.datamodule = dm

# 1 no DM
Expand Down
27 changes: 15 additions & 12 deletions tests/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock

import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
GradientAccumulationScheduler,
Expand All @@ -36,18 +35,22 @@ def test_checkpoint_callbacks_are_last(tmpdir):
lr_monitor = LearningRateMonitor()
progress_bar = ProgressBar()

# no model callbacks
model = Mock()
model.configure_callbacks.return_value = []
# no model reference
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
trainer.model = model
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

# no model callbacks
model = LightningModule()
model.configure_callbacks = lambda: []
trainer.model = model
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

# with model-specific callbacks that substitute ones in Trainer
model = Mock()
model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2]
model = LightningModule()
model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2]
trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
trainer.model = model
cb_connector = CallbackConnector(trainer)
Expand Down Expand Up @@ -89,8 +92,8 @@ def test_attach_model_callbacks():
"""Test that the callbacks defined in the model and through Trainer get merged correctly."""

def assert_composition(trainer_callbacks, model_callbacks, expected):
model = Mock()
model.configure_callbacks.return_value = model_callbacks
model = LightningModule()
model.configure_callbacks = lambda: model_callbacks
trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
trainer.model = model
cb_connector = CallbackConnector(trainer)
Expand Down Expand Up @@ -140,8 +143,8 @@ def assert_composition(trainer_callbacks, model_callbacks, expected):

def test_attach_model_callbacks_override_info(caplog):
"""Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
model = Mock()
model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()]
model = LightningModule()
model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
trainer.model = model
cb_connector = CallbackConnector(trainer)
Expand Down
Loading