Add Windows Support for DeepSpeed #8488

Merged
merged 7 commits on Jul 20, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -493,6 +493,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


- Fixed DeepSpeed Windows support ([#8488](https://github.com/PyTorchLightning/pytorch-lightning/pull/8488))


## [1.3.8] - 2021-07-01

### Fixed
27 changes: 26 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -15,6 +15,7 @@
import json
import logging
import os
import platform
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
@@ -29,7 +30,7 @@
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
@@ -340,6 +341,30 @@ def setup_distributed(self):
        self._format_config()
        self._config_initialized = True

    def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
        if platform.system() != "Windows":
            # do not set env variables on windows, allow deepspeed to control setup
            global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
            world_size = world_size if world_size is not None else self.cluster_environment.world_size()
            self._set_node_environment_variables(global_rank, world_size)
            log.info(
                f"initializing deepspeed distributed: "
                f"GLOBAL_RANK: {global_rank}, "
                f"MEMBER: {global_rank + 1}/{world_size}"
            )
        deepspeed.init_distributed(
            self.torch_distributed_backend, distributed_port=self.cluster_environment.master_port()
        )

    def _set_node_environment_variables(
        self, global_rank: Optional[int] = None, world_size: Optional[int] = None
    ) -> None:
        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["RANK"] = str(global_rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(self.local_rank)

    def pre_dispatch(self):
        self.init_deepspeed()
        self.barrier()
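For orientation, a minimal usage sketch of what this change enables, not taken from the PR itself: the model, dataloader, `gpus=1`, and `precision=16` values are illustrative assumptions, and only the `DeepSpeedPlugin(stage=3)` construction mirrors the test added below. It assumes a CUDA-capable machine (Windows or Linux) with `deepspeed` installed.

```python
# Illustrative sketch (not part of this diff): with DeepSpeed importable on Windows
# and the plugin skipping the MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE/LOCAL_RANK
# exports there, a plain Trainer + DeepSpeedPlugin setup is expected to work.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        return self.layer(x).sum()  # toy "loss", just for the sketch

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

trainer = pl.Trainer(
    gpus=1,                              # illustrative; DeepSpeed needs CUDA devices
    precision=16,                        # illustrative; DeepSpeed is commonly run in fp16
    plugins=[DeepSpeedPlugin(stage=3)],  # same construction as the test added below
)
trainer.fit(TinyModel(), train_loader)
```

On Windows, the plugin above now reaches `deepspeed.init_distributed()` without first exporting the rank environment variables, which is exactly what the new test asserts.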
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/imports.py
@@ -72,7 +72,7 @@ def _compare_version(package: str, op, version) -> bool:

_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
_DEEPSPEED_AVAILABLE = _module_available('deepspeed')
_FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn')
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")
36 changes: 36 additions & 0 deletions tests/plugins/test_deepspeed_plugin.py
@@ -1,6 +1,7 @@
import json
import os
from typing import Any, Dict
from unittest import mock

import pytest
import torch
@@ -694,6 +695,41 @@ def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
    trainer.test(model)


@RunIf(deepspeed=True)
@mock.patch('deepspeed.init_distributed', autospec=True)
@pytest.mark.parametrize("platform", ["Linux", "Windows"])
def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, platform):
    """
    Test to ensure that we set up distributed communication correctly.
    When using Windows, rank environment variables should not be set, and DeepSpeed should handle this.
    """
    trainer = Trainer(
        default_root_dir=tmpdir,
        plugins=[DeepSpeedPlugin(stage=3)],
    )
    plugin = trainer.training_type_plugin
    assert isinstance(plugin, DeepSpeedPlugin)
    with mock.patch('platform.system', return_value=platform) as mock_platform:
        plugin.init_ddp_connection()
    mock_deepspeed_distributed.assert_called()
    mock_platform.assert_called()
    if platform == 'Windows':
        # assert no env variables have been set within the DeepSpeedPlugin
        assert all(k not in os.environ for k in (
            "MASTER_PORT",
            "MASTER_ADDR",
            "RANK",
            "WORLD_SIZE",
            "LOCAL_RANK",
        ))
    else:
        assert os.environ["MASTER_ADDR"] == str(trainer.training_type_plugin.cluster_environment.master_address())
        assert os.environ["MASTER_PORT"] == str(trainer.training_type_plugin.cluster_environment.master_port())
        assert os.environ["RANK"] == str(trainer.training_type_plugin.global_rank)
        assert os.environ["WORLD_SIZE"] == str(trainer.training_type_plugin.world_size)
        assert os.environ["LOCAL_RANK"] == str(trainer.training_type_plugin.local_rank)


def _assert_save_model_is_equal(model, tmpdir, trainer, cls=BoringModel):
    checkpoint_path = os.path.join(tmpdir, 'model.pt')
    trainer.save_checkpoint(checkpoint_path)