Update Accelerator Connector for Registry #7214

Merged 5 commits on May 3, 2021
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -82,6 +82,7 @@ def register(

def do_register(plugin: Callable) -> Callable:
data["plugin"] = plugin
data["distributed_backend"] = plugin.distributed_backend
self[name] = data
return plugin

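For context, here is a minimal dict-backed sketch of what the `do_register` closure stores per name after this change; the surrounding `register()` signature is an assumption, since only the closure body is visible in the diff:

```python
from typing import Any, Callable, Dict

# Illustrative stand-in for the real registry object.
registry: Dict[str, Dict[str, Any]] = {}

def register(name: str, description: str = "", **init_params: Any) -> Callable:
    data: Dict[str, Any] = {"description": description, "init_params": init_params}

    def do_register(plugin: Callable) -> Callable:
        data["plugin"] = plugin
        # new in this PR: remember the plugin's distributed_backend at registration time
        data["distributed_backend"] = plugin.distributed_backend
        registry[name] = data
        return plugin

    return do_register

@register("my_plugin", description="demo", param1="abc")  # hypothetical usage
class MyPlugin:
    distributed_backend = "my_plugin"
```

The point of the extra key is that anything holding only the registered name can later look up which distributed backend it maps to, which is what the accelerator connector changes below rely on.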
38 changes: 24 additions & 14 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -112,6 +112,16 @@ def __init__(
self._training_type_plugin: Optional[TrainingTypePlugin] = None
self._cluster_environment: Optional[ClusterEnvironment] = None

plugins = plugins if plugins is not None else []

if isinstance(plugins, str):
plugins = [plugins]

if not isinstance(plugins, Sequence):
plugins = [plugins]

self.plugins = plugins

# for gpus allow int, string and gpu list
if auto_select_gpus and isinstance(gpus, int):
self.gpus = pick_multiple_gpus(gpus)
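The plugin normalization that previously lived in `handle_given_plugins` now runs here in `__init__` and lands on `self.plugins`. A standalone sketch of that coercion, with a made-up helper name:

```python
from collections.abc import Sequence
from typing import Any, List

def _coerce_plugins(plugins: Any) -> List[Any]:
    # Hypothetical helper mirroring the coercion above: always hand back a list.
    if plugins is None:
        return []
    if isinstance(plugins, str):
        return [plugins]   # a single registry key / backend name
    if not isinstance(plugins, Sequence):
        return [plugins]   # a single plugin or cluster-environment instance
    return list(plugins)
```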
@@ -121,7 +131,7 @@ def __init__(
self.set_distributed_mode()
self.configure_slurm_ddp()

self.handle_given_plugins(plugins)
self.handle_given_plugins()

self.accelerator = self.select_accelerator()

@@ -148,22 +158,13 @@ def __init__(

self.replace_sampler_ddp = replace_sampler_ddp

def handle_given_plugins(
self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
):
plugins = plugins if plugins is not None else []

if isinstance(plugins, str):
plugins = [plugins]

if not isinstance(plugins, Sequence):
plugins = [plugins]
def handle_given_plugins(self) -> None:

training_type = None
precision = None
cluster_environment = None

for plug in plugins:
for plug in self.plugins:
if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
if training_type is None:
training_type = TrainingTypePluginsRegistry.get(plug)
@@ -173,7 +174,7 @@ def handle_given_plugins(
' Found more than 1 training type plugin:'
f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
)
elif isinstance(plug, str):
if isinstance(plug, str):
# Reset the distributed type as the user has overridden training type
# via the plugins argument
self._distrib_type = None
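A self-contained sketch of the lookup order in this hunk, with a plain dict standing in for `TrainingTypePluginsRegistry`; the rest of the real method is omitted. The switch from `elif` to `if` means a string that matches the registry also falls through to the reset, so the user-supplied name overrides the previously inferred distributed type:

```python
from typing import Any, Dict, Optional, Sequence, Tuple, Union

def resolve_training_type(
    plugins: Sequence[Union[str, Any]],
    registry: Dict[str, Dict[str, Any]],
) -> Tuple[Optional[Any], Optional[str]]:
    """Illustrative only; names and return shape are stand-ins, not the real API."""
    training_type: Optional[Any] = None
    distrib_type: Optional[str] = "inferred"  # pretend something was already inferred

    for plug in plugins:
        if isinstance(plug, str) and plug in registry:
            if training_type is not None:
                raise ValueError(f"Found more than 1 training type plugin: {plug}")
            training_type = registry[plug]["plugin"]
        if isinstance(plug, str):
            # any string plugin, including a registry key, clears the
            # previously inferred distributed type
            distrib_type = None

    return training_type, distrib_type
```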
@@ -310,6 +311,10 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
def root_gpu(self) -> Optional[int]:
return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None

@property
def is_training_type_in_plugins(self) -> bool:
return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)

Comment on lines +314 to +317

Contributor:
n00b question: could you explain why this iterates over all the plugins in the connector? From the naming I'd assume this was meant to check whether a single plugin instance is contained in the registry.

Contributor Author:
Yup!

@property
def is_using_torchelastic(self) -> bool:
"""
@@ -492,7 +497,12 @@ def select_cluster_environment(self) -> ClusterEnvironment:

def set_distributed_mode(self, distributed_backend: Optional[str] = None):

if distributed_backend is not None:
if distributed_backend is None and self.is_training_type_in_plugins:
return

if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
elif distributed_backend is not None:
self.distributed_backend = distributed_backend

if isinstance(self.distributed_backend, Accelerator):
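The new entry branches of `set_distributed_mode` in isolation, again with a dict standing in for the registry: when no backend was given but a registered training type already arrived via `plugins`, the method returns early; when the given backend name is itself a registry key, it is mapped to the `distributed_backend` recorded at registration:

```python
from typing import Any, Dict, Optional

def pick_distributed_backend(
    distributed_backend: Optional[str],
    registry: Dict[str, Dict[str, Any]],
    training_type_in_plugins: bool,
) -> Optional[str]:
    # Sketch of the branches above only; the rest of set_distributed_mode is unchanged.
    if distributed_backend is None and training_type_in_plugins:
        return None  # the registry entry supplied via `plugins` decides the mode
    if distributed_backend is not None and distributed_backend in registry:
        return registry[distributed_backend]["distributed_backend"]
    return distributed_backend
```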
3 changes: 3 additions & 0 deletions tests/plugins/test_plugins_registry.py
@@ -23,6 +23,8 @@ def test_training_type_plugins_registry_with_new_plugin():

class TestPlugin:

distributed_backend = "test_plugin"

def __init__(self, param1, param2):
self.param1 = param1
self.param2 = param2
@@ -37,6 +39,7 @@ def __init__(self, param1, param2):
assert plugin_name in TrainingTypePluginsRegistry
assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
assert TrainingTypePluginsRegistry[plugin_name]["distributed_backend"] == "test_plugin"
assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)

TrainingTypePluginsRegistry.remove(plugin_name)
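Taken together, this lets a registered name be passed straight through the `plugins` argument. Roughly, under the assumption that the name below was registered beforehand (it is made up for illustration):

```python
from pytorch_lightning import Trainer

# "my_registered_ddp" is a hypothetical registry key; it must have been
# registered with TrainingTypePluginsRegistry first. The connector then
# resolves both the training type plugin and its distributed_backend
# from the registry entry.
trainer = Trainer(plugins="my_registered_ddp")
```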