Rewrite accelerator_connector #11448

Merged · 71 commits · Feb 17, 2022

Commits
eb3a03c
Rewrite accelerator_connector
four4fish Jan 12, 2022
50a82d2
update
four4fish Jan 25, 2022
7307969
update
four4fish Jan 25, 2022
3999d80
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2022
c2730f9
update
four4fish Jan 26, 2022
c01aee5
remove print
four4fish Jan 26, 2022
d45eba0
fix more tests
four4fish Jan 27, 2022
ec17b31
change trainer.gpus
four4fish Jan 27, 2022
ffeea28
fix tests
four4fish Jan 28, 2022
57b1642
remove gpu avalible check
four4fish Jan 28, 2022
d374aa9
update
four4fish Jan 28, 2022
0083b69
fix horovod
four4fish Jan 29, 2022
7a5c3ba
fix horovod
four4fish Jan 29, 2022
ca96f84
debug tpu
four4fish Jan 29, 2022
e55a524
fix global rank
four4fish Jan 31, 2022
9996fea
fix horovod
four4fish Feb 1, 2022
a14879c
Update pytorch_lightning/utilities/exceptions.py
four4fish Feb 2, 2022
1626eee
update horovod
four4fish Feb 2, 2022
e134119
address some ananth's comments
four4fish Feb 2, 2022
7c1eb85
draft
four4fish Feb 2, 2022
92deb7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2022
8aa1f68
fix ipus and cli tests
four4fish Feb 3, 2022
f4cca3c
fix typo
four4fish Feb 3, 2022
677c6f1
fix tests
four4fish Feb 3, 2022
5351621
fix pre commit
four4fish Feb 3, 2022
836eb98
address comments
four4fish Feb 4, 2022
18c4d9e
rename ttp to strategy
awaelchli Feb 5, 2022
0bbc1c4
fix typo
awaelchli Feb 5, 2022
2d54316
add typing to constructor
awaelchli Feb 5, 2022
f7eee05
update on comments
awaelchli Feb 5, 2022
1022b25
typing, documentation improvements, adding todo's
awaelchli Feb 5, 2022
932e28a
fix amp_level, amp_type mixup
awaelchli Feb 6, 2022
3286de3
more typing fixes
awaelchli Feb 6, 2022
774f35d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2022
5be85d3
Update tests/models/test_gpu.py
four4fish Feb 7, 2022
f27d01c
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
four4fish Feb 7, 2022
74cbfed
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
four4fish Feb 7, 2022
d54ccfc
Apply suggestions from code review
four4fish Feb 7, 2022
653b5b8
support bagua
four4fish Feb 7, 2022
d000e0d
rename distributed_backend to strategy_name
four4fish Feb 8, 2022
344a5e6
distributed_backend to strategy_name in tests/
four4fish Feb 8, 2022
1707696
Fix tpu tests
kaushikb11 Feb 10, 2022
05a03d0
draft
four4fish Feb 10, 2022
d4c78f8
add device=0 error message and update tests
four4fish Feb 10, 2022
77d2cd1
fix gpu teests
four4fish Feb 10, 2022
d8c5ccc
test revert accelerator auto logic
four4fish Feb 10, 2022
917039f
Tiny update to choose accelerator
kaushikb11 Feb 11, 2022
f2d53fa
fix ipu and gpu tests
four4fish Feb 11, 2022
266d3f8
add special handling for ipustrategy
four4fish Feb 11, 2022
0f833f9
Address comments
four4fish Feb 11, 2022
6b434e2
address comments and add kaushik's suggestions
four4fish Feb 14, 2022
3560c55
Apply suggestions from code review
four4fish Feb 14, 2022
55547bc
fix mypy
four4fish Feb 16, 2022
cc684f1
address comments and fix mypy
four4fish Feb 16, 2022
ce18f52
Updates to attributes
kaushikb11 Feb 17, 2022
88db830
Improve exceptions
kaushikb11 Feb 17, 2022
ee70db8
Updates to attributes
kaushikb11 Feb 17, 2022
c516830
Add utility methods
kaushikb11 Feb 17, 2022
8b07218
Handle zero/empty list values for devices flag
kaushikb11 Feb 17, 2022
caaf390
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
cd12345
Apply suggestions from code review
four4fish Feb 17, 2022
a442852
address comments
four4fish Feb 17, 2022
3152f81
minor comments change
four4fish Feb 17, 2022
2c2e5ac
fix tests
four4fish Feb 17, 2022
71dcb82
Merge branch 'master' into rewrite/acc_con
four4fish Feb 17, 2022
5f32feb
minor fix
four4fish Feb 17, 2022
f2ab1d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
a6ff2c3
add _configure_launcher call to accl_conn
four4fish Feb 17, 2022
62ce92e
Merge branch 'rewrite/acc_con' of https://github.com/four4fish/pytorc…
four4fish Feb 17, 2022
869e571
Apply suggestions from code review
four4fish Feb 17, 2022
9568f3b
Apply suggestions from code review
four4fish Feb 17, 2022
Files changed
1 change: 0 additions & 1 deletion pyproject.toml
@@ -82,7 +82,6 @@ module = [
"pytorch_lightning.profiler.pytorch",
"pytorch_lightning.profiler.simple",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.accelerator_connector",
"pytorch_lightning.trainer.connectors.callback_connector",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.data_loading",
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -29,7 +29,6 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
@@ -127,7 +126,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
if not trainer.logger:
raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

if trainer._device_type != _AcceleratorType.GPU:
if trainer.strategy.root_device.type != "cuda":
raise MisconfigurationException(
"You are using GPUStatsMonitor but are not running on GPU"
f" since gpus attribute in Trainer is set to {trainer.gpus}."
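The updated check above keys off the strategy's root_device rather than the removed _AcceleratorType enum. A minimal sketch of how such a device-type test behaves with plain torch.device objects; the running_on_gpu helper is only illustrative and not part of the PR:

    import torch

    # torch.device exposes a `.type` string ("cpu", "cuda", ...), which is what the
    # new `trainer.strategy.root_device.type != "cuda"` comparison relies on.
    def running_on_gpu(root_device: torch.device) -> bool:
        return root_device.type == "cuda"

    print(running_on_gpu(torch.device("cpu")))      # False
    print(running_on_gpu(torch.device("cuda", 0)))  # True (building the device object needs no GPU)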
2 changes: 1 addition & 1 deletion pytorch_lightning/lite/lite.py
@@ -82,7 +82,7 @@ def __init__(
self._check_strategy_support(strategy)
gpu_ids, tpu_cores = _parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
self._accelerator_connector = AcceleratorConnector(
num_processes=1,
num_processes=None,
devices=devices,
tpu_cores=tpu_cores,
ipus=None,
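Passing num_processes=None here, rather than a hard-coded 1, presumably leaves the process count for the rewritten connector to resolve from the device-related flags. For context, those flags surface publicly on Trainer; a hedged usage sketch with illustrative flag values only:

    from pytorch_lightning import Trainer

    # Select an accelerator, a device count, and a strategy by its registered name.
    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp")

    # CPU fallback with a single device.
    cpu_trainer = Trainer(accelerator="cpu", devices=1)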
11 changes: 7 additions & 4 deletions pytorch_lightning/strategies/bagua.py
@@ -13,7 +13,6 @@
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
from pytorch_lightning.utilities.seed import reset_seed
@@ -58,7 +57,7 @@ def __init__(self, pl_module: "pl.LightningModule") -> None:


class BaguaStrategy(DDPStrategy):
distributed_backend = _StrategyType.BAGUA
strategy_name = "bagua"

def __init__(
self,
@@ -180,8 +179,12 @@ def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
)

@classmethod
def register_plugins(cls, plugin_registry: Dict) -> None:
plugin_registry.register("bagua", cls, description="Default Bagua Plugin")
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
# abort the background communication for async algorithm
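BaguaStrategy now registers itself under its strategy_name through the same register_strategies hook as the other strategies, replacing the old register_plugins call and the _StrategyType enum key. A self-contained sketch of the underlying name-to-class registry pattern; SimpleStrategyRegistry and DummyStrategy are made up for illustration and are not Lightning's actual StrategyRegistry (the sketch uses cls.__name__ for the description):

    from typing import Any, Dict, Type


    class SimpleStrategyRegistry(dict):
        """Maps a strategy name to the class plus the init kwargs needed to build it."""

        def register(self, name: str, strategy_cls: Type, description: str = "", **init_params: Any) -> None:
            self[name] = {"strategy": strategy_cls, "description": description, "init_params": init_params}

        def instantiate(self, name: str) -> Any:
            entry = self[name]
            return entry["strategy"](**entry["init_params"])


    class DummyStrategy:
        strategy_name = "dummy"

        def __init__(self, find_unused_parameters: bool = True) -> None:
            self.find_unused_parameters = find_unused_parameters

        @classmethod
        def register_strategies(cls, strategy_registry: Dict) -> None:
            # Same shape as the hooks added throughout this PR: one entry per registered name.
            strategy_registry.register(cls.strategy_name, cls, description=cls.__name__)
            strategy_registry.register(
                "dummy_find_unused_parameters_false",
                cls,
                description="Dummy strategy with find_unused_parameters=False",
                find_unused_parameters=False,
            )


    registry = SimpleStrategyRegistry()
    DummyStrategy.register_strategies(registry)
    print(sorted(registry))                                      # ['dummy', 'dummy_find_unused_parameters_false']
    print(registry.instantiate("dummy").find_unused_parameters)  # True

The per-class strategy_name attribute is what replaces the old distributed_backend enum value as the key each strategy registers itself under.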
10 changes: 6 additions & 4 deletions pytorch_lightning/strategies/ddp.py
@@ -45,7 +45,6 @@
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
@@ -63,7 +62,7 @@
class DDPStrategy(ParallelStrategy):
"""Strategy for multi-process single-device training on one or multiple nodes."""

distributed_backend = _StrategyType.DDP
strategy_name = "ddp"

def __init__(
self,
@@ -96,7 +95,6 @@ def __init__(
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_will_call_children_scripts: bool = False
self.set_world_ranks()

@property
def is_distributed(self) -> bool:
@@ -114,7 +112,6 @@ def num_nodes(self) -> int:
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def num_processes(self):
@@ -346,6 +343,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description="DDP Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def _should_run_deadlock_detection(self) -> bool:
"""Determines whether the plugin will perform process reconciliation in case of errors.
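Note that the constructor and the num_nodes setter above no longer call self.set_world_ranks() directly; rank setup is apparently deferred until the connector has attached a cluster environment. A toy sketch of that ordering concern, with ToyStrategy and ToyClusterEnvironment as hypothetical stand-ins:

    class ToyClusterEnvironment:
        def __init__(self) -> None:
            self.world_size = 1

        def set_world_size(self, size: int) -> None:
            self.world_size = size


    class ToyStrategy:
        def __init__(self, num_nodes: int = 1, num_processes: int = 1) -> None:
            # Note: no set_world_ranks() here -- a cluster environment may not
            # have been attached yet when the strategy object is constructed.
            self.cluster_environment = None
            self.num_nodes = num_nodes
            self.num_processes = num_processes

        def set_world_ranks(self) -> None:
            if self.cluster_environment is None:
                return
            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)


    strategy = ToyStrategy(num_nodes=2, num_processes=4)
    strategy.cluster_environment = ToyClusterEnvironment()  # attached later, e.g. by the connector
    strategy.set_world_ranks()                              # only now is it safe to compute ranks
    print(strategy.cluster_environment.world_size)          # 8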
13 changes: 11 additions & 2 deletions pytorch_lightning/strategies/ddp2.py
@@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import torch

from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DDP2Strategy(DDPStrategy):
"""DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP."""

distributed_backend = _StrategyType.DDP2
strategy_name = "ddp2"

@property
def global_rank(self) -> int:
@@ -73,3 +74,11 @@ def set_world_ranks(self) -> None:
return
self.cluster_environment.set_global_rank(self.node_rank)
self.cluster_environment.set_world_size(self.num_nodes)

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
10 changes: 6 additions & 4 deletions pytorch_lightning/strategies/ddp_spawn.py
@@ -33,7 +33,6 @@
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -48,7 +47,7 @@ class DDPSpawnStrategy(ParallelStrategy):
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
finishes."""

distributed_backend = _StrategyType.DDP_SPAWN
strategy_name = "ddp_spawn"

def __init__(
self,
@@ -76,7 +75,6 @@ def __init__(
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._local_rank = 0
self.set_world_ranks()

@property
def num_nodes(self) -> int:
@@ -86,7 +84,6 @@ def num_nodes(self) -> int:
def num_nodes(self, num_nodes: int) -> None:
# note that world ranks is related to num_nodes, when resetting it, need to reset world ranks
self._num_nodes = num_nodes
self.set_world_ranks()

@property
def local_rank(self) -> int:
@@ -264,6 +261,11 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description="DDPSpawn Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
super().teardown()
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log
from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
from pytorch_lightning.utilities.enums import AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -82,7 +82,7 @@ def _move_float_tensors_to_half(self, batch: Any):


class DeepSpeedStrategy(DDPStrategy):
distributed_backend = _StrategyType.DEEPSPEED
strategy_name = "deepspeed"
DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

def __init__(
13 changes: 10 additions & 3 deletions pytorch_lightning/strategies/dp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

import torch
from torch.nn import DataParallel, Module
@@ -22,7 +22,6 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT

@@ -31,7 +30,7 @@ class DataParallelStrategy(ParallelStrategy):
"""Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
each gets a split of the data."""

distributed_backend = _StrategyType.DP
strategy_name = "dp"

def __init__(
self,
@@ -149,6 +148,14 @@ def training_step_end(self, output):

return output

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
super().teardown()
if self.root_device.type == "cuda":
10 changes: 8 additions & 2 deletions pytorch_lightning/strategies/fully_sharded.py
@@ -23,7 +23,7 @@
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -36,7 +36,7 @@

class DDPFullyShardedStrategy(DDPStrategy):

distributed_backend = _StrategyType.DDP_FULLY_SHARDED
strategy_name = "ddp_fully_sharded"

def __init__(
self,
@@ -212,3 +212,9 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"fsdp", cls, description="Fully sharded training with checkpointing the full state dict."
)

strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
13 changes: 10 additions & 3 deletions pytorch_lightning/strategies/horovod.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -26,7 +26,6 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

@@ -37,7 +36,7 @@
class HorovodStrategy(ParallelStrategy):
"""Plugin for Horovod distributed training integration."""

distributed_backend = _StrategyType.HOROVOD
strategy_name = "horovod"

def __init__(
self,
@@ -196,6 +195,14 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup
opt_params = {p for group in optimizer.param_groups for p in group.get("params", [])}
return [(name, p) for name, p in model.named_parameters() if p in opt_params]

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)

def teardown(self) -> None:
super().teardown()
# teardown may be called before `_exit_stack` is set
12 changes: 11 additions & 1 deletion pytorch_lightning/strategies/ipu.py
@@ -13,7 +13,7 @@
# limitations under the License.
import json
import os
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import DataLoader
@@ -62,6 +62,8 @@ def _move_float_tensors_to_half(self, batch: Any) -> Any:
class IPUStrategy(ParallelStrategy):
"""Plugin for training on IPU devices."""

strategy_name = "ipu_strategy"

def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
@@ -360,3 +362,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra

def broadcast(self, obj: object, src: int = 0) -> object:
return obj

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)
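IPUStrategy now carries an explicit strategy_name ("ipu_strategy") and registers itself like the other strategies. Separately, the commit list above mentions an error for device=0 and handling of zero/empty-list values for the devices flag; a hedged sketch of that kind of validation, where the helper name and message are illustrative rather than the PR's exact code:

    from typing import List, Optional, Union


    def check_devices_flag(devices: Optional[Union[int, str, List[int]]]) -> None:
        # Illustrative only: zero or an empty list would select no devices at all,
        # which this sketch treats as a misconfiguration.
        if devices in (0, "0", []):
            raise ValueError(
                f"devices={devices!r} selects no devices; pass a positive count, a list of device ids, or 'auto'."
            )


    check_devices_flag(2)       # ok
    check_devices_flag([0, 1])  # ok
    # check_devices_flag(0)     # would raise ValueError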
9 changes: 7 additions & 2 deletions pytorch_lightning/strategies/sharded.py
@@ -22,7 +22,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
@@ -37,7 +37,7 @@
class DDPShardedStrategy(DDPStrategy):
"""Optimizer and gradient sharded training provided by FairScale."""

distributed_backend = _StrategyType.DDP_SHARDED
strategy_name = "ddp_sharded"
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M

def configure_ddp(self) -> None:
@@ -135,3 +135,8 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description="DDP Sharded Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
cls.strategy_name,
cls,
description=f"{cls.__class__.__name__}",
)