
Commit e1d518b
Merge branch 'master' into bugfix/callback-state
awaelchli committed Apr 17, 2021
2 parents f585a28 + 7b0b0d2 commit e1d518b
Showing 6 changed files with 13 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -202,6 +202,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed multi-node DDP sub-process launch by using `local_rank` instead of `global_rank` for main process assertion ([#7061](https://github.com/PyTorchLightning/pytorch-lightning/pull/7061))
+
+
 - Fixed incorrect removal of `WORLD_SIZE` environment variable in DDP training when launching with torch distributed/torchelastic ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942))
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = '1.3.0rc1'
+__version__ = '1.3.0rc2'
 __author__ = 'William Falcon et al.'
 __author_email__ = '[email protected]'
 __license__ = 'Apache-2.0'
7 changes: 1 addition & 6 deletions pytorch_lightning/accelerators/accelerator.py
@@ -372,12 +372,7 @@ def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None:
 
     def to_device(self, batch: Any) -> Any:
         """Pushes the batch to the root device"""
-        # Todo (tchaton) Better fix
-        is_dict = isinstance(batch, dict)
-        if is_dict:
-            batch = [batch]
-        batch = self.batch_to_device(batch, self.root_device)
-        return batch[0] if is_dict else batch
+        return self.batch_to_device(batch, self.root_device)
 
     @property
     def amp_backend(self) -> Optional[LightningEnum]:
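
The removed workaround predates `batch_to_device` handling dicts on its own: the transfer delegates to Lightning's collection-aware `move_data_to_device`, so wrapping the dict in a list is unnecessary. A minimal sketch of the behavior relied on, assuming the 1.3-series import path:

```python
# Minimal sketch: move_data_to_device walks arbitrary collections
# (dicts included), so to_device() needs no dict special-casing.
# The import path is an assumption based on the 1.3-series layout.
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device

batch = {"x": torch.zeros(2, 3), "y": [torch.ones(2), torch.ones(2)]}
moved = move_data_to_device(batch, torch.device("cpu"))
assert isinstance(moved, dict) and moved["x"].device.type == "cpu"
```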
9 changes: 6 additions & 3 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -16,7 +16,7 @@
 from collections import UserDict
 from inspect import getmembers, isclass
 from pathlib import Path
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -75,7 +75,7 @@ def register(
                 " HINT: Use `override=True`."
             )
 
-        data = {}
+        data: Dict[str, Any] = {}
         data["description"] = description if description is not None else ""
 
         data["init_params"] = init_params
@@ -90,7 +90,7 @@ def do_register(plugin: Callable) -> Callable:
 
         return do_register
 
-    def get(self, name: str) -> Any:
+    def get(self, name: str, default: Optional[Any] = None) -> Any:
         """
         Calls the registered plugin with the required parameters
         and returns the plugin object
@@ -102,6 +102,9 @@ def get(self, name: str) -> Any:
             data = self[name]
             return data["plugin"](**data["init_params"])
 
+        if default is not None:
+            return default
+
         err_msg = "'{}' not found in registry. Available names: {}"
         available_names = ", ".join(sorted(self.keys())) or "none"
         raise KeyError(err_msg.format(name, available_names))
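
With this change, a lookup miss can return a caller-supplied fallback instead of raising. A self-contained sketch of the new `get` semantics; `PluginRegistry` here is a hypothetical stand-in, not the actual Lightning class:

```python
# Hypothetical stand-in mirroring the new get() semantics above.
from collections import UserDict
from typing import Any, Optional


class PluginRegistry(UserDict):
    def get(self, name: str, default: Optional[Any] = None) -> Any:
        if name in self:
            data = self[name]
            return data["plugin"](**data["init_params"])
        if default is not None:
            return default
        raise KeyError(f"'{name}' not found in registry.")


registry = PluginRegistry()
fallback = object()
assert registry.get("missing", default=fallback) is fallback  # falls back, no KeyError
```

Note that an explicit `default=None` still raises `KeyError`, because the fallback only triggers for non-None defaults.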
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -110,7 +110,7 @@ def setup_environment(self):
     def _call_children_scripts(self):
 
         # bookkeeping of spawned processes
-        assert self.global_rank == 0
+        assert self.local_rank == 0
         self._check_can_spawn_children()
         self._has_spawned_children = True
 
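
On each node, the process that launches children is that node's `local_rank == 0` process, while only node 0's launcher also holds `global_rank == 0`; asserting on the global rank therefore broke the launch on every node but the first. A small worked illustration of the rank arithmetic (a sketch; the variable names are ours, not Lightning API):

```python
# Every node's launcher has local_rank == 0, but only node 0's launcher also
# has global_rank == 0, which is why the old assertion failed on nodes > 0.
num_nodes, procs_per_node = 2, 4
for node_rank in range(num_nodes):
    local_rank = 0  # the main process on each node
    global_rank = node_rank * procs_per_node + local_rank
    print(node_rank, local_rank, global_rank)  # node 1 -> global_rank 4, not 0
```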
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -544,7 +544,7 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
         return current_global_step
 
     @classmethod
-    def register_plugins(cls, plugin_registry):
+    def register_plugins(cls, plugin_registry: Dict) -> None:
         plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
         plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
         plugin_registry.register(
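
The names registered here are what users pick a DeepSpeed variant by. A usage sketch, assuming (as of the 1.3 series) that the `Trainer` accepts registry names through its `plugins` argument:

```python
# Sketch: selecting a registered DeepSpeed variant by its registry name.
from pytorch_lightning import Trainer

trainer = Trainer(gpus=4, precision=16, plugins="deepspeed_stage_2")
```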
