Fix accumulate_grad_batches on init #9652

Merged
merged 13 commits on Sep 24, 2021
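In short: `accumulate_grad_batches` is now resolved by the callback connector into a `GradientAccumulationScheduler` callback, so `trainer.accumulate_grad_batches` is already an int right after construction instead of a raw dict. A rough sketch of the intended behaviour, based on the connector changes below (assumes `get_accumulate_grad_batches(0)` returns the factor scheduled for epoch 0; not a reproduction of the test suite):

```python
from pytorch_lightning import Trainer

# The signature default changes from 1 to None; the connector resolves None to 1.
trainer = Trainer()
assert trainer.accumulate_grad_batches == 1

# A dict schedule is wrapped in a GradientAccumulationScheduler callback and the
# trainer attribute holds the epoch-0 factor as an int instead of the dict itself.
trainer = Trainer(accumulate_grad_batches={0: 4, 10: 8})
assert trainer.accumulate_grad_batches == 4
assert trainer.accumulation_scheduler in trainer.callbacks
```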
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -399,6 +399,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed back-compatibility for saving hyperparameters from a single container and inferring its argument name by reverting [#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125) ([#9642](https://github.com/PyTorchLightning/pytorch-lightning/pull/9642))


- Fixed `trainer.accumulate_grad_batches` to be an int on init ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652))


## [1.4.7] - 2021-09-14

- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
24 changes: 8 additions & 16 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -19,7 +19,6 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -96,7 +95,6 @@ def __init__(
self.autoreport = autoreport
self.autoreport_dir = autoreport_dir
self.poptorch_models = {}
self._original_accumulate_grad_batches: Optional[int] = None
self._training_opts = training_opts
self._inference_opts = inference_opts

@@ -150,7 +148,7 @@ def _create_opts(self, training: bool) -> "poptorch.Options":
opts = poptorch.Options()
opts.deviceIterations(self.device_iterations)
opts.replicationFactor(self.replication_factor)
gradient_accumulation = self.accumulate_grad_batches if training else 1
gradient_accumulation = self.lightning_module.trainer.accumulate_grad_batches if training else 1
opts.Training.gradientAccumulation(gradient_accumulation)

if os.environ.get("PL_GLOBAL_SEED"):
@@ -185,27 +183,21 @@ def _convert_to_poptorch_loader(
dataloader = poptorch.DataLoader(**dl_kwargs, options=opts)
return dataloader

@property
def accumulate_grad_batches(self) -> int:
return self._original_accumulate_grad_batches

def _handle_gradient_accumulation_steps(self) -> None:
"""Override the trainer.accumulation_scheduler to act as ``accumulate_grad_batches=1`` if gradient
accumulation has been set.

``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
"""
accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
if not isinstance(accumulate_grad_batches, int):
accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

if accumulation_scheduler.epochs != [0]:
raise MisconfigurationException(
"IPUs currently only support `Trainer.accumulate_grad_batches` being an integer."
f" Received {accumulate_grad_batches}"
"IPUs currently does not support different `accumulate_grad_batches` at different epoch."
)
# save the original value which will be used to update the global step progress
self._original_accumulate_grad_batches = accumulate_grad_batches
if accumulate_grad_batches > 1:
# TODO(@tchaton): Add support for accumulate_grad_batches being a dictionary
self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})

# TODO(@tchaton): Add support for accumulate_grad_batches being a dictionary
accumulation_scheduler.scheduling.update({0: 1})
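A hedged usage sketch of the constraint enforced above (illustrative only; assumes an IPU-enabled environment and a `LightningModule` instance named `model`):

```python
from pytorch_lightning import Trainer

# Schedules that change the factor across epochs are rejected once fit() sets up the plugin.
trainer = Trainer(ipus=1, accumulate_grad_batches={0: 1, 4: 2})
# trainer.fit(model)  # raises MisconfigurationException: factor differs across epochs

# A constant factor works: poptorch accumulates internally via
# opts.Training.gradientAccumulation(...), and the Lightning schedule is forced to {0: 1}.
trainer = Trainer(ipus=1, accumulate_grad_batches=8)
# trainer.fit(model)  # optimizer_step runs every batch; the IPU handles the accumulation
```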

@property
def _n_replicate(self):
33 changes: 33 additions & 0 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -17,6 +17,7 @@

from pytorch_lightning.callbacks import (
Callback,
GradientAccumulationScheduler,
ModelCheckpoint,
ModelSummary,
ProgressBar,
@@ -45,6 +46,7 @@ def on_trainer_init(
weights_summary: Optional[str],
stochastic_weight_avg: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
):
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
@@ -87,10 +89,40 @@ def on_trainer_init(
# configure the ModelSummary callback
self._configure_model_summary_callback(weights_summary)

# accumulated grads
self._configure_accumulated_gradients(accumulate_grad_batches)

# push all checkpoint callbacks to the end
# it is important that these are the last callbacks to run
self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)

def _configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
grad_accum_callback = [cb for cb in self.trainer.callbacks if isinstance(cb, GradientAccumulationScheduler)]

if grad_accum_callback:
if accumulate_grad_batches is not None:
raise MisconfigurationException(
"You have set both `accumulate_grad_batches` and passed an "
"instance of `GradientAccumulationScheduler` inside callbacks."
)
grad_accum_callback = grad_accum_callback[0]
else:
if accumulate_grad_batches is None:
accumulate_grad_batches = 1

if isinstance(accumulate_grad_batches, dict):
grad_accum_callback = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
grad_accum_callback = GradientAccumulationScheduler({0: accumulate_grad_batches})
else:
raise MisconfigurationException(
f"Gradient accumulation supports only int and dict types, got {accumulate_grad_batches}"
)

self.trainer.callbacks.append(grad_accum_callback)
self.trainer.accumulate_grad_batches = grad_accum_callback.get_accumulate_grad_batches(0)
self.trainer.accumulation_scheduler = grad_accum_callback
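Illustrative only (names as used in this diff): the connector now rejects configuring gradient accumulation through both paths at once.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Supplying the Trainer argument *and* the callback is ambiguous and now raises.
    Trainer(accumulate_grad_batches=4, callbacks=[GradientAccumulationScheduler({0: 4})])
except MisconfigurationException as err:
    print(err)  # "You have set both `accumulate_grad_batches` and passed an instance ..."
```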

def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
# TODO: Remove this error in v1.5 so we rely purely on the type signature
if not isinstance(checkpoint_callback, bool):
@@ -126,6 +158,7 @@ def _configure_model_summary_callback(self, weights_summary: Optional[str] = Non
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
self.trainer.weights_summary = weights_summary

def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
from typing import Union

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -27,7 +26,6 @@ def on_trainer_init(
gradient_clip_val: float,
gradient_clip_algorithm: str,
track_grad_norm: Union[int, float, str],
accumulate_grad_batches: Union[int, Dict[int, int]],
terminate_on_nan: bool,
):

@@ -43,16 +41,3 @@
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != "inf":
raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
self.trainer.track_grad_norm = float(track_grad_norm)

# accumulated grads
self.trainer.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
if isinstance(accumulate_grad_batches, dict):
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {0: accumulate_grad_batches}
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")
8 changes: 3 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -140,7 +140,7 @@ def __init__(
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
max_epochs: Optional[int] = None,
min_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
@@ -435,8 +435,6 @@ def __init__(
# default .predict() loop
self.predict_loop = PredictionLoop()

self.weights_summary = weights_summary

# Needed because of LightningOptimizer
self._lightning_optimizers = None

@@ -454,9 +452,10 @@
process_position,
default_root_dir,
weights_save_path,
self.weights_summary,
weights_summary,
stochastic_weight_avg,
max_time,
accumulate_grad_batches,
)

# hook
@@ -478,7 +477,6 @@
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
terminate_on_nan,
)
self._setup_on_init(num_sanity_val_steps)
4 changes: 2 additions & 2 deletions tests/accelerators/test_ipu.py
@@ -310,9 +310,9 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da
@RunIf(ipu=True)
def test_accumulate_grad_batches_dict_fails(tmpdir):
model = IPUModel()
trainer = Trainer(default_root_dir=tmpdir, ipus=1, accumulate_grad_batches={0: 1})
trainer = Trainer(default_root_dir=tmpdir, ipus=1, accumulate_grad_batches={1: 2})
with pytest.raises(
MisconfigurationException, match="IPUs currently only support `Trainer.accumulate_grad_batches` being an int"
MisconfigurationException, match="IPUs currently does not support different `accumulate_grad_batches`"
):
trainer.fit(model)

52 changes: 37 additions & 15 deletions tests/trainer/connectors/test_callback_connector.py
@@ -41,14 +41,28 @@ def test_checkpoint_callbacks_are_last(tmpdir):
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]
assert trainer.callbacks == [
progress_bar,
lr_monitor,
model_summary,
trainer.accumulation_scheduler,
checkpoint1,
checkpoint2,
]

# no model callbacks
model = LightningModule()
model.configure_callbacks = lambda: []
trainer.model = model
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]
assert trainer.callbacks == [
progress_bar,
lr_monitor,
model_summary,
trainer.accumulation_scheduler,
checkpoint1,
checkpoint2,
]

# with model-specific callbacks that substitute ones in Trainer
model = LightningModule()
@@ -57,7 +71,15 @@ def assert_composition(trainer_callbacks, model_callbacks, expected):
trainer.model = model
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, model_summary, checkpoint1, checkpoint2]
assert trainer.callbacks == [
progress_bar,
lr_monitor,
trainer.accumulation_scheduler,
early_stopping,
model_summary,
checkpoint1,
checkpoint2,
]


class StatefulCallback0(Callback):
@@ -118,7 +140,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
def test_attach_model_callbacks():
"""Test that the callbacks defined in the model and through Trainer get merged correctly."""

def assert_composition(trainer_callbacks, model_callbacks, expected):
def _attach_callbacks(trainer_callbacks, model_callbacks):
model = LightningModule()
model.configure_callbacks = lambda: model_callbacks
trainer = Trainer(
@@ -127,37 +149,37 @@ def assert_composition(trainer_callbacks, model_callbacks, expected):
trainer.model = model
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == expected
return trainer

early_stopping = EarlyStopping()
progress_bar = ProgressBar()
lr_monitor = LearningRateMonitor()
grad_accumulation = GradientAccumulationScheduler({1: 1})

# no callbacks
assert_composition(trainer_callbacks=[], model_callbacks=[], expected=[])
trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[])
assert trainer.callbacks == [trainer.accumulation_scheduler]

# callbacks of different types
assert_composition(
trainer_callbacks=[early_stopping], model_callbacks=[progress_bar], expected=[early_stopping, progress_bar]
)
trainer = _attach_callbacks(trainer_callbacks=[early_stopping], model_callbacks=[progress_bar])
assert trainer.callbacks == [early_stopping, trainer.accumulation_scheduler, progress_bar]

# same callback type twice, different instance
assert_composition(
trainer = _attach_callbacks(
trainer_callbacks=[progress_bar, EarlyStopping()],
model_callbacks=[early_stopping],
expected=[progress_bar, early_stopping],
)
assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

# multiple callbacks of the same type in trainer
assert_composition(
trainer = _attach_callbacks(
trainer_callbacks=[LearningRateMonitor(), EarlyStopping(), LearningRateMonitor(), EarlyStopping()],
model_callbacks=[early_stopping, lr_monitor],
expected=[early_stopping, lr_monitor],
)
assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]

# multiple callbacks of the same type, in both trainer and model
assert_composition(
trainer = _attach_callbacks(
trainer_callbacks=[
LearningRateMonitor(),
progress_bar,
@@ -166,8 +188,8 @@ def assert_composition(trainer_callbacks, model_callbacks, expected):
EarlyStopping(),
],
model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
expected=[progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping],
)
assert trainer.callbacks == [progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping]


def test_attach_model_callbacks_override_info(caplog):
Expand Down
Loading