Skip to content

Commit

Permalink
Remove Trainer._strategy_type (#11990)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Feb 21, 2022
1 parent e15a664 commit 0771a55
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 237 deletions.
1 change: 0 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# pointer to the trainer object
self.trainer = None

self._strategy_type = None
self._device_type = None

# true if using amp
Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/loops/epoch/prediction_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
"""Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
:class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
# the batch_sampler is not be defined in case of CombinedDataLoaders
batch_sampler = getattr(self.trainer.predict_dataloaders[dataloader_idx], "batch_sampler", None)
batch_sampler = getattr(
self.trainer.predict_dataloaders[dataloader_idx], # type: ignore[has-type]
"batch_sampler",
None,
)
if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
return batch_sampler.seen_batch_indices

Expand Down
5 changes: 0 additions & 5 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
TPUSpawnStrategy,
)
from pytorch_lightning.utilities import (
_StrategyType,
AMPType,
device_parser,
LightningEnum,
Expand Down Expand Up @@ -875,7 +874,3 @@ def has_tpu(self) -> bool:
@property
def use_dp(self) -> bool:
return isinstance(self.strategy, DataParallelStrategy)

@property
def _strategy_type(self) -> _StrategyType:
return self.strategy.strategy_name
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand All @@ -34,7 +35,6 @@
has_iterable_dataset,
has_len_all_ranks,
)
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand Down Expand Up @@ -217,7 +217,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if not isinstance(dataloader, DataLoader):
return

using_spawn = self.trainer._accelerator_connector._strategy_type == _StrategyType.DDP_SPAWN
using_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy)
num_cpus = multiprocessing.cpu_count()

# ddp_spawn + num_workers > 0 don't mix! tell the user
Expand Down
12 changes: 1 addition & 11 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
from pytorch_lightning.utilities import (
_AcceleratorType,
_IPU_AVAILABLE,
_StrategyType,
_TPU_AVAILABLE,
AMPType,
device_parser,
Expand Down Expand Up @@ -1963,10 +1962,6 @@ def should_rank_save_checkpoint(self) -> bool:
isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
)

@property
def _strategy_type(self) -> str:
return self.strategy.strategy_name

@property
def _device_type(self) -> _AcceleratorType:
return self._accelerator_connector.device_type
Expand Down Expand Up @@ -2126,12 +2121,7 @@ def distributed_sampler_kwargs(self) -> Optional[dict]:

@property
def data_parallel(self) -> bool:
return self._strategy_type in (
_StrategyType.DP,
_StrategyType.DDP,
_StrategyType.DDP_SPAWN,
_StrategyType.DDP2,
)
return isinstance(self.strategy, ParallelStrategy)

@property
def progress_bar_dict(self) -> dict:
Expand Down
6 changes: 2 additions & 4 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
ParallelStrategy,
SingleDeviceStrategy,
)
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
from pytorch_lightning.utilities import _AcceleratorType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -580,11 +580,9 @@ def test_devices_with_cpu_only_supports_integer():

@pytest.mark.parametrize("training_type", ["ddp2", "dp"])
def test_unsupported_strategy_types_on_cpu(training_type):

with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"):
trainer = Trainer(accelerator=training_type, num_processes=2)

assert trainer._strategy_type == _StrategyType.DDP
assert isinstance(trainer.strategy, DDPStrategy)


def test_accelerator_ddp_for_cpu(tmpdir):
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/test_data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.data import _update_dataloader
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.helpers import BoringModel, RandomDataset
Expand Down Expand Up @@ -133,7 +133,7 @@ def _get_warning_msg():
@pytest.mark.parametrize("num_workers", [0, 1])
def test_dataloader_warnings(tmpdir, num_workers):
trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4)
assert trainer._accelerator_connector._strategy_type == _StrategyType.DDP_SPAWN
assert isinstance(trainer.strategy, DDPSpawnStrategy)
trainer.fit(TestSpawnBoringModel(num_workers))


Expand Down
Loading

0 comments on commit 0771a55

Please sign in to comment.