From 79183b6294d53bce849fc82ca1a6f999e20dda93 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 18 Feb 2022 15:26:35 +0100
Subject: [PATCH 1/4] Remove `Trainer._strategy_type`

---
 pytorch_lightning/core/lightning.py              |   3 -
 .../connectors/accelerator_connector.py          |   5 -
 .../trainer/connectors/data_connector.py         |   4 +-
 pytorch_lightning/trainer/trainer.py             |  12 +-
 .../test_accelerator_connector.py                |   6 +-
 tests/trainer/test_data_loading.py               |   4 +-
 tests/trainer/test_trainer.py                    | 279 +++++-------------
 7 files changed, 75 insertions(+), 238 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 098956a703a8a..e9c129c09b57c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -95,9 +95,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # pointer to the trainer object
         self.trainer = None
 
-        self._strategy_type = None
-        self._device_type = None
-
         # true if using amp
         self.use_amp: bool = False
 
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 20c5f485b4e71..62df53e6e454d 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -65,7 +65,6 @@
     TPUSpawnStrategy,
 )
 from pytorch_lightning.utilities import (
-    _StrategyType,
     AMPType,
     device_parser,
     LightningEnum,
@@ -855,7 +854,3 @@ def has_tpu(self) -> bool:
     @property
     def use_dp(self) -> bool:
         return isinstance(self.strategy, DataParallelStrategy)
-
-    @property
-    def _strategy_type(self) -> _StrategyType:
-        return self.strategy.strategy_name
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index ef79bd88db822..1bedfed63cd0b 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -22,6 +22,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,7 +35,6 @@
     has_iterable_dataset,
     has_len_all_ranks,
 )
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -217,7 +217,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         if not isinstance(dataloader, DataLoader):
             return
 
-        using_spawn = self.trainer._accelerator_connector._strategy_type == _StrategyType.DDP_SPAWN
+        using_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy)
         num_cpus = multiprocessing.cpu_count()
 
         # ddp_spawn + num_workers > 0 don't mix! tell the user
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6ed5d6c31f719..2c9dce091ae96 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -77,7 +77,6 @@
 from pytorch_lightning.utilities import (
     _AcceleratorType,
     _IPU_AVAILABLE,
-    _StrategyType,
     _TPU_AVAILABLE,
     AMPType,
     device_parser,
@@ -1963,10 +1962,6 @@ def should_rank_save_checkpoint(self) -> bool:
             isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
         )
 
-    @property
-    def _strategy_type(self) -> str:
-        return self.strategy.strategy_name
-
     @property
     def _device_type(self) -> _AcceleratorType:
         return self._accelerator_connector.device_type
@@ -2126,12 +2121,7 @@ def distributed_sampler_kwargs(self) -> Optional[dict]:
 
     @property
     def data_parallel(self) -> bool:
-        return self._strategy_type in (
-            _StrategyType.DP,
-            _StrategyType.DDP,
-            _StrategyType.DDP_SPAWN,
-            _StrategyType.DDP2,
-        )
+        return isinstance(self.strategy, ParallelStrategy)
 
     @property
     def progress_bar_dict(self) -> dict:
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 76fa6d64f5a56..b837fdc6b824c 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -42,7 +42,7 @@
     ParallelStrategy,
     SingleDeviceStrategy,
 )
-from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
+from pytorch_lightning.utilities import _AcceleratorType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.runif import RunIf
 
@@ -580,11 +580,9 @@ def test_devices_with_cpu_only_supports_integer():
 
 @pytest.mark.parametrize("training_type", ["ddp2", "dp"])
 def test_unsupported_strategy_types_on_cpu(training_type):
-
     with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"):
         trainer = Trainer(accelerator=training_type, num_processes=2)
-
-    assert trainer._strategy_type == _StrategyType.DDP
+    assert isinstance(trainer.strategy, DDPStrategy)
 
 
 def test_accelerator_ddp_for_cpu(tmpdir):
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index be063aab0bf95..edd63057620ef 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -19,10 +19,10 @@
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.data import _update_dataloader
-from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringModel, RandomDataset
@@ -133,7 +133,7 @@ def _get_warning_msg():
 @pytest.mark.parametrize("num_workers", [0, 1])
 def test_dataloader_warnings(tmpdir, num_workers):
     trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4)
-    assert trainer._accelerator_connector._strategy_type == _StrategyType.DDP_SPAWN
+    assert isinstance(trainer.strategy, DDPSpawnStrategy)
     trainer.fit(TestSpawnBoringModel(num_workers))
 
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 0d2d8bbdc55b6..122c60bf7d971 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -45,9 +45,9 @@
     DDPSpawnShardedStrategy,
     DDPSpawnStrategy,
     DDPStrategy,
+    SingleDeviceStrategy,
 )
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _AcceleratorType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _IS_WINDOWS, _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
@@ -1172,95 +1172,44 @@ def val_dataloader(self):
     assert mocked.call_count == sum(trainer.num_val_batches)
 
 
-@pytest.mark.parametrize(  # TODO: please update tests, @daniellepintz
-    "trainer_kwargs,expected",
+@pytest.mark.parametrize(
+    ["trainer_kwargs", "strategy_cls", "strategy_name", "_device_type", "num_gpus"],
     [
-        (
-            dict(accelerator=None, gpus=None),
-            dict(_strategy_type="single_device", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="dp", gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="ddp", gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="ddp", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="ddp", num_nodes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="ddp_cpu", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="ddp2", gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator=None, gpus=1),
-            dict(_strategy_type="single_device", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(accelerator="dp", gpus=1),
-            dict(_strategy_type="dp", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(accelerator="ddp", gpus=1),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(accelerator="ddp_cpu", num_processes=2, gpus=1),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="ddp2", gpus=1),
-            dict(_strategy_type="ddp2", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(accelerator=None, gpus=2),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(accelerator="dp", gpus=2),
-            dict(_strategy_type="dp", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(accelerator="ddp", gpus=2),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(accelerator="ddp2", gpus=2),
-            dict(_strategy_type="ddp2", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(accelerator="ddp2", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(accelerator="dp", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
+        ({"accelerator": None}, SingleDeviceStrategy, "single_device", "cpu", 0),
+        ({"accelerator": "dp"}, DDPStrategy, "ddp", "cpu", 0),
+        ({"accelerator": "ddp"}, DDPStrategy, "ddp", "cpu", 0),
+        ({"accelerator": "ddp", "num_processes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"accelerator": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"accelerator": "ddp_cpu", "num_processes": 2}, DDPSpawnStrategy, "ddp_spawn", "cpu", 0),
+        ({"accelerator": "ddp2"}, DDPStrategy, "ddp", "cpu", 0),
+        ({"accelerator": None, "gpus": 1}, SingleDeviceStrategy, "single_device", "gpu", 1),
+        ({"accelerator": "dp", "gpus": 1}, DataParallelStrategy, "dp", "gpu", 1),
+        ({"accelerator": "ddp", "gpus": 1}, DDPStrategy, "ddp", "gpu", 1),
+        ({"accelerator": "ddp_cpu", "num_processes": 2, "gpus": 1}, DDPSpawnStrategy, "ddp_spawn", "cpu", 0),
+        ({"accelerator": "ddp2", "gpus": 1}, DDP2Strategy, "ddp2", "gpu", 1),
+        ({"accelerator": None, "gpus": 2}, DDPSpawnStrategy, "ddp_spawn", "gpu", 2),
+        ({"accelerator": "dp", "gpus": 2}, DataParallelStrategy, "dp", "gpu", 2),
+        ({"accelerator": "ddp", "gpus": 2}, DDPStrategy, "ddp", "gpu", 2),
+        ({"accelerator": "ddp2", "gpus": 2}, DDP2Strategy, "ddp2", "gpu", 2),
+        ({"accelerator": "ddp2", "num_processes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"accelerator": "dp", "num_processes": 2}, DDPStrategy, "ddp", "cpu", 0),
     ],
 )
-def test_trainer_config(trainer_kwargs, expected, monkeypatch):
-    if trainer_kwargs["gpus"] is not None:
+def test_trainer_config_accelerator(monkeypatch, trainer_kwargs, strategy_cls, strategy_name, _device_type, num_gpus):
+    if trainer_kwargs.get("gpus") is not None:
         monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
         monkeypatch.setattr(torch.cuda, "device_count", lambda: trainer_kwargs["gpus"])
+
     if trainer_kwargs["accelerator"] in (None, "ddp_cpu"):
         trainer = Trainer(**trainer_kwargs)
     else:
         with pytest.deprecated_call(match=r"accelerator='.*'\)` has been deprecated in v1.5"):
             trainer = Trainer(**trainer_kwargs)
-    assert len(expected) == 3
-    for k, v in expected.items():
-        assert getattr(trainer, k) == v, f"Failed on {trainer_kwargs}, where {k}={ getattr(trainer, k)}, not {v}"
+
+    assert isinstance(trainer.strategy, strategy_cls)
+    assert strategy_cls.strategy_name == strategy_name
+    assert getattr(trainer, "_device_type") == _device_type
+    assert trainer.num_gpus == num_gpus
 
 
 def test_trainer_subclassing():
@@ -2092,140 +2041,48 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
 
-@pytest.mark.parametrize(  # TODO: please update tests, @daniellepintz
-    "trainer_kwargs,expected",
+@pytest.mark.parametrize(
+    ["trainer_kwargs", "strategy_cls", "strategy_name", "_device_type", "num_gpus"],
     [
-        (
-            dict(strategy=None, gpus=None),
-            dict(_strategy_type="single_device", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="dp", gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp", gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp", num_nodes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp2", gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy=None, gpus=1),
-            dict(_strategy_type="single_device", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(strategy="dp", gpus=1),
-            dict(_strategy_type="dp", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(strategy="ddp", gpus=1),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(strategy="ddp_spawn", gpus=1),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(strategy="ddp2", gpus=1),
-            dict(_strategy_type="ddp2", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(strategy=None, gpus=2),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy="dp", gpus=2),
-            dict(_strategy_type="dp", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy="ddp", gpus=2),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy="ddp2", gpus=2),
-            dict(_strategy_type="ddp2", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy="ddp2", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="dp", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp_spawn", num_processes=2, gpus=None),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp_spawn", num_processes=1, gpus=None),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy="ddp_fully_sharded", gpus=1),
-            dict(_strategy_type="ddp_fully_sharded", _device_type=_AcceleratorType.GPU, num_gpus=1),
-        ),
-        (
-            dict(strategy=DDPSpawnStrategy(), num_processes=2, gpus=None),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy=DDPSpawnStrategy(), gpus=2),
-            dict(_strategy_type="ddp_spawn", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy=DDPStrategy(), num_processes=2, gpus=None),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.CPU, num_gpus=0),
-        ),
-        (
-            dict(strategy=DDPStrategy(), gpus=2),
-            dict(_strategy_type="ddp", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy=DDP2Strategy(), gpus=2),
-            dict(_strategy_type="ddp2", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy=DataParallelStrategy(), gpus=2),
-            dict(_strategy_type="dp", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
-        (
-            dict(strategy=DDPFullyShardedStrategy(), gpus=2),
-            dict(
-                _strategy_type="ddp_fully_sharded",
-                _device_type=_AcceleratorType.GPU,
-                num_gpus=2,
-            ),
-        ),
-        (
-            dict(strategy=DDPSpawnShardedStrategy(), gpus=2),
-            dict(
-                _strategy_type="ddp_sharded_spawn",
-                _device_type=_AcceleratorType.GPU,
-                num_gpus=2,
-            ),
-        ),
-        (
-            dict(strategy=DDPShardedStrategy(), gpus=2),
-            dict(_strategy_type="ddp_sharded", _device_type=_AcceleratorType.GPU, num_gpus=2),
-        ),
+        ({"strategy": None}, SingleDeviceStrategy, "single_device", "cpu", 0),
+        ({"strategy": "dp"}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": "ddp"}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": "ddp", "num_processes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": "ddp", "num_nodes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": "ddp2"}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": None, "gpus": 1}, SingleDeviceStrategy, "single_device", "gpu", 1),
+        ({"strategy": "dp", "gpus": 1}, DataParallelStrategy, "dp", "gpu", 1),
+        ({"strategy": "ddp", "gpus": 1}, DDPStrategy, "ddp", "gpu", 1),
+        ({"strategy": "ddp_spawn", "gpus": 1}, DDPSpawnStrategy, "ddp_spawn", "gpu", 1),
+        ({"strategy": "ddp2", "gpus": 1}, DDP2Strategy, "ddp2", "gpu", 1),
+        ({"strategy": None, "gpus": 2}, DDPSpawnStrategy, "ddp_spawn", "gpu", 2),
+        ({"strategy": "dp", "gpus": 2}, DataParallelStrategy, "dp", "gpu", 2),
+        ({"strategy": "ddp", "gpus": 2}, DDPStrategy, "ddp", "gpu", 2),
+        ({"strategy": "ddp2", "gpus": 2}, DDP2Strategy, "ddp2", "gpu", 2),
+        ({"strategy": "ddp2", "num_processes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": "ddp", "num_processes": 2}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": "ddp_spawn", "num_processes": 2}, DDPSpawnStrategy, "ddp_spawn", "cpu", 0),
+        ({"strategy": "ddp_spawn", "num_processes": 1}, DDPSpawnStrategy, "ddp_spawn", "cpu", 0),
+        ({"strategy": "ddp_fully_sharded", "gpus": 1}, DDPFullyShardedStrategy, "ddp_fully_sharded", "gpu", 1),
+        ({"strategy": DDPSpawnStrategy(), "num_processes": 2}, DDPSpawnStrategy, "ddp_spawn", "cpu", 0),
+        ({"strategy": DDPSpawnStrategy(), "gpus": 2}, DDPSpawnStrategy, "ddp_spawn", "gpu", 2),
+        ({"strategy": DDPStrategy()}, DDPStrategy, "ddp", "cpu", 0),
+        ({"strategy": DDPStrategy(), "gpus": 2}, DDPStrategy, "ddp", "gpu", 2),
+        ({"strategy": DDP2Strategy(), "gpus": 2}, DDP2Strategy, "ddp2", "gpu", 2),
+        ({"strategy": DataParallelStrategy(), "gpus": 2}, DataParallelStrategy, "dp", "gpu", 2),
+        ({"strategy": DDPFullyShardedStrategy(), "gpus": 2}, DDPFullyShardedStrategy, "ddp_fully_sharded", "gpu", 2),
+        ({"strategy": DDPSpawnShardedStrategy(), "gpus": 2}, DDPSpawnShardedStrategy, "ddp_sharded_spawn", "gpu", 2),
+        ({"strategy": DDPShardedStrategy(), "gpus": 2}, DDPShardedStrategy, "ddp_sharded", "gpu", 2),
     ],
 )
-def test_trainer_config_strategy(trainer_kwargs, expected, monkeypatch):
-    if trainer_kwargs["gpus"] is not None:
+def test_trainer_config_strategy(monkeypatch, trainer_kwargs, strategy_cls, strategy_name, _device_type, num_gpus):
+    if trainer_kwargs.get("gpus") is not None:
         monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
         monkeypatch.setattr(torch.cuda, "device_count", lambda: trainer_kwargs["gpus"])
+
     trainer = Trainer(**trainer_kwargs)
-    assert len(expected) == 3
-    for k, v in expected.items():
-        assert getattr(trainer, k) == v, f"Failed on {trainer_kwargs}, where {k}={ getattr(trainer, k)}, not {v}"
+
+    assert isinstance(trainer.strategy, strategy_cls)
+    assert strategy_cls.strategy_name == strategy_name
+    assert getattr(trainer, "_device_type") == _device_type
+    assert trainer.num_gpus == num_gpus

From e21d607c28a8de71431e2b63b5c480d138668cea Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 18 Feb 2022 15:29:18 +0100
Subject: [PATCH 2/4] Minor change

---
 tests/trainer/test_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 122c60bf7d971..e260210fbd6a5 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1208,7 +1208,7 @@ def test_trainer_config_accelerator(monkeypatch, trainer_kwargs, strategy_cls, s
 
     assert isinstance(trainer.strategy, strategy_cls)
     assert strategy_cls.strategy_name == strategy_name
-    assert getattr(trainer, "_device_type") == _device_type
+    assert trainer._device_type == _device_type
     assert trainer.num_gpus == num_gpus
 
 
@@ -2084,5 +2084,5 @@ def test_trainer_config_strategy(monkeypatch, trainer_kwargs, strategy_cls, stra
 
     assert isinstance(trainer.strategy, strategy_cls)
     assert strategy_cls.strategy_name == strategy_name
-    assert getattr(trainer, "_device_type") == _device_type
+    assert trainer._device_type == _device_type
     assert trainer.num_gpus == num_gpus

From d37ff3e492a95d9eb056e8d0485cd9d5567683d0 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 18 Feb 2022 15:31:03 +0100
Subject: [PATCH 3/4] Undo change

---
 pytorch_lightning/core/lightning.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index e9c129c09b57c..76b4890e56c62 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -95,6 +95,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # pointer to the trainer object
         self.trainer = None
 
+        self._device_type = None
+
         # true if using amp
         self.use_amp: bool = False
 

From 2c75a7bf16d95f01ee7c6683b325b1c55e8c463d Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 18 Feb 2022 15:59:42 +0100
Subject: [PATCH 4/4] mypy

---
 pytorch_lightning/loops/epoch/prediction_epoch_loop.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index 4306d4ecb936c..ca3555cd79d68 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -164,7 +164,11 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
         """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
         :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
         # the batch_sampler is not be defined in case of CombinedDataLoaders
-        batch_sampler = getattr(self.trainer.predict_dataloaders[dataloader_idx], "batch_sampler", None)
+        batch_sampler = getattr(
+            self.trainer.predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
+            "batch_sampler",
+            None,
+        )
         if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
             return batch_sampler.seen_batch_indices