
[DeepSpeed] Do not fail if batch size could not be inferred for logging (#10438)
Sean Naren authored Nov 16, 2021
1 parent 4117028 commit e98ace3
Showing 3 changed files with 22 additions and 27 deletions.
CHANGELOG.md: 3 additions, 0 deletions
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `monitor` argument in the `EarlyStopping` callback is no longer optional ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328))


+- Do not fail if batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438))
+
+
 - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


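One practical note for readers of this change: training no longer crashes when the batch size cannot be inferred, but DeepSpeed's internal logging then falls back to a default value. A minimal sketch of the workaround named in the new warning message, passing the logging batch size explicitly (the value 32 and the single-GPU setup are illustrative, not part of this commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# With logging_batch_size_per_gpu set up front, the plugin never needs to call
# train_dataloader() just to fill in DeepSpeed's train_micro_batch_size_per_gpu.
trainer = Trainer(
    strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=32),  # illustrative value
    gpus=1,
)
```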
pytorch_lightning/plugins/training_type/deepspeed.py: 13 additions, 8 deletions
@@ -618,11 +618,6 @@ def _format_batch_size_and_grad_accum_config(self):
             )
             self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
         if "train_micro_batch_size_per_gpu" not in self.config:
-            rank_zero_warn(
-                "Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. "
-                "If you require skipping this, please pass "
-                "`Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`"
-            )
             batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         if "gradient_clipping" not in self.config:
@@ -634,9 +629,19 @@ def _auto_select_batch_size(self):
         batch_size = 1
         train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
         if train_dl_source.is_defined():
-            train_dataloader = train_dl_source.dataloader()
-            if hasattr(train_dataloader, "batch_sampler"):
-                batch_size = train_dataloader.batch_sampler.batch_size
+            try:
+                train_dataloader = train_dl_source.dataloader()
+                if hasattr(train_dataloader, "batch_sampler"):
+                    batch_size = train_dataloader.batch_sampler.batch_size
+            # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup`
+            # to have been called before
+            except Exception:
+                if self.global_rank == 0:
+                    deepspeed.utils.logging.logger.warning(
+                        "Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. "
+                        "To ensure DeepSpeed logging remains correct, please manually pass the plugin with the "
+                        "batch size, `Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`."
+                    )
         return batch_size

     def _format_precision_config(self):
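Distilled out of the hunk above, the pattern is: treat any failure while building the dataloader as "could not infer", and degrade to a rank-zero warning instead of an exception. A standalone sketch (names such as `dataloader_factory` are illustrative, not Lightning's API):

```python
import logging

logger = logging.getLogger(__name__)


def auto_select_batch_size(dataloader_factory, global_rank: int, default: int = 1) -> int:
    """Best-effort batch size inference that never raises."""
    batch_size = default
    try:
        # dataloader_factory may raise, e.g. if setup() has not been called yet.
        train_dataloader = dataloader_factory()
        if hasattr(train_dataloader, "batch_sampler"):
            batch_size = train_dataloader.batch_sampler.batch_size
    except Exception:  # broad on purpose: any failure just means "could not infer"
        if global_rank == 0:  # warn on one process only, not once per GPU
            logger.warning("Could not infer the batch size for logging; falling back to %d.", default)
    return batch_size
```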
tests/plugins/test_deepspeed_plugin.py: 6 additions, 19 deletions
@@ -1,5 +1,6 @@
 import contextlib
 import json
+import logging
 import os
 from typing import Any, Dict, Optional
 from unittest import mock
@@ -872,24 +873,9 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)


-@RunIf(min_gpus=1, deepspeed=True, special=True)
-def test_deepspeed_warn_train_dataloader_called(tmpdir):
-    """Test DeepSpeed warns when it calls ``lightning_module.train_dataloader`` internally for logging batch
-    size."""
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        strategy=DeepSpeedPlugin(),
-        gpus=1,
-        fast_dev_run=True,
-    )
-    with pytest.warns(UserWarning, match="Inferring the batch size for internal deepspeed logging"):
-        trainer.fit(model)
-
-
 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_setup_train_dataloader(tmpdir):
-    """Test DeepSpeed works when setup is required to call, and the user passes the batch size manually."""
+    """Test DeepSpeed works when setup is required to call in the DataModule."""

     class TestSetupIsCalledDataModule(LightningDataModule):
         def __init__(self):
@@ -914,13 +900,14 @@ def test_dataloader(self):
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=32),
+        strategy=DeepSpeedPlugin(logging_level=logging.INFO),
         gpus=1,
         fast_dev_run=True,
     )
     dm = TestSetupIsCalledDataModule()
-    trainer.fit(model, datamodule=dm)
-    trainer.test(model, datamodule=dm)
+    with mock.patch("deepspeed.utils.logging.logger.warning", autospec=True) as mock_object:
+        trainer.fit(model, datamodule=dm)
+    assert any("Tried to infer the batch size" in str(arg) for arg in mock_object.call_args_list)


 @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)
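The rewritten test asserts on the warning by patching the DeepSpeed logger, since the message now goes through `deepspeed.utils.logging.logger` rather than Python's warning system (so `pytest.warns` no longer sees it). A self-contained sketch of that mock-and-assert pattern, using a stdlib logger here so it runs without DeepSpeed installed:

```python
import logging
from unittest import mock

logger = logging.getLogger("demo")


def run_step() -> None:
    # Stand-in for trainer.fit(); the real code warns via deepspeed.utils.logging.logger.
    logger.warning("Tried to infer the batch size for internal logging.")


def test_warning_is_captured() -> None:
    with mock.patch.object(logger, "warning", autospec=True) as mock_object:
        run_step()
    # call_args_list records every invocation of the patched method.
    assert any("Tried to infer the batch size" in str(arg) for arg in mock_object.call_args_list)
```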
