Add `strategy` argument to Trainer #8597
Changes from all commits
```diff
@@ -94,6 +94,7 @@ def __init__(
         ipus,
         distributed_backend,
         accelerator,
+        strategy: Optional[Union[str, TrainingTypePlugin]],
         gpus,
         gpu_ids,
         num_nodes,
```
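For context, a minimal sketch of how the new flag is intended to be used from user code. This is not taken from the diff; the strategy name and process count shown are illustrative, assuming a PyTorch Lightning 1.5-style API:

```python
from pytorch_lightning import Trainer

# Select the training strategy by name instead of routing it through
# `accelerator=` or the deprecated `distributed_backend=` flag.
trainer = Trainer(strategy="ddp_spawn", num_processes=2)
```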
```diff
@@ -111,12 +112,9 @@ def __init__(
         self._distrib_type = None
         self._accelerator_type = None

-        if distributed_backend is not None:
-            rank_zero_deprecation(
-                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
-                f" Use `Trainer(accelerator={distributed_backend})` instead."
-            )
-        distributed_backend = distributed_backend or accelerator
+        self.strategy = strategy.lower() if isinstance(strategy, str) else strategy
+        self.distributed_backend = distributed_backend or accelerator

         self._init_deterministic(deterministic)

         self.num_processes = num_processes
```

```diff
@@ -126,7 +124,6 @@ def __init__(
         self.parallel_device_ids = gpu_ids
         self.tpu_cores = tpu_cores
         self.ipus = ipus
-        self.distributed_backend = distributed_backend
         self.num_nodes = num_nodes
         self.sync_batchnorm = sync_batchnorm
         self.benchmark = benchmark
```
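The `strategy` value is normalised here: strings are lower-cased, plugin instances are stored as-is. A hedged sketch of the two accepted forms; the `DDPPlugin` arguments and the GPU count are illustrative and assume GPUs are available:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Either a known string alias (stored lower-cased, so "DDP_SPAWN" becomes "ddp_spawn")...
Trainer(strategy="DDP_SPAWN", num_processes=2)

# ...or a fully configured TrainingTypePlugin instance.
Trainer(strategy=DDPPlugin(find_unused_parameters=False), gpus=2)
```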
```diff
@@ -151,16 +148,23 @@ def __init__(

         self.plugins = plugins

+        self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator)
+
         self._validate_accelerator_and_devices()

         self._warn_if_devices_flag_ignored()

         self.select_accelerator_type()
-        self.set_distributed_mode()
+
+        if self.strategy is not None:
+            self._set_training_type_plugin()
+        else:
+            self.set_distributed_mode()
         self.configure_slurm_ddp()

         self.handle_given_plugins()
         self.update_device_type_if_ipu_plugin()
+        self.update_device_type_if_training_type_plugin_passed()

         self._validate_accelerator_type()
         self._set_devices_if_none()
```
```diff
@@ -228,11 +232,11 @@ def select_accelerator_type(self) -> None:
         self._set_devices_to_cpu_num_processes()
         self._accelerator_type = DeviceType.CPU

-        if self.distributed_backend in ["auto"] + list(DeviceType):
+        if self.distributed_backend in self.accelerator_types:
             self.distributed_backend = None

     def _validate_accelerator_and_devices(self) -> None:
-        if self.distributed_backend not in ["auto"] + list(DeviceType) and self.devices is not None:
+        if self.distributed_backend not in self.accelerator_types and self.devices is not None:
             raise MisconfigurationException(
                 f"You passed `devices={self.devices}` but haven't specified"
                 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping,"
```
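The new `accelerator_types` property (added further down in this diff) replaces the repeated `["auto"] + list(DeviceType)` expression. Because `DeviceType` is a string-valued enum, plain strings such as `"gpu"` match its members. A small stand-alone sketch of that membership behaviour, using a stand-in enum rather than the real `DeviceType`:

```python
from enum import Enum

class DeviceType(str, Enum):  # stand-in for pytorch_lightning's DeviceType enum
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    IPU = "ipu"

accelerator_types = ["auto"] + list(DeviceType)

# String flags compare equal to the str-valued enum members,
# so both spellings pass the membership check.
assert "gpu" in accelerator_types
assert DeviceType.GPU in accelerator_types
assert "ddp" not in accelerator_types  # strategy names are not accelerator types
```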
```diff
@@ -285,9 +289,56 @@ def _set_devices_if_none(self) -> None:
         elif self._accelerator_type == DeviceType.CPU:
             self.devices = self.num_processes

+    def _handle_accelerator_and_distributed_backend(
+        self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]]
+    ) -> None:
+        if distributed_backend is not None:
+            rank_zero_deprecation(
+                f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5."
+                f" Use `Trainer(strategy={distributed_backend})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(distributed_backend={distributed_backend})`."
+                    f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+        if accelerator is not None and accelerator in list(DistributedType):
+            rank_zero_deprecation(
+                f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated"
+                f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead."
+            )
+            if self.strategy is not None:
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})` but have"
+                    f" also passed `Trainer(accelerator={accelerator})`."
+                    f"HINT: Use just `Trainer(strategy={self.strategy})` instead."
+                )
+
+    def _set_training_type_plugin(self) -> None:
+        if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry:
+            self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy)
+        if isinstance(self.strategy, str):
+            self.set_distributed_mode(self.strategy)
+        elif isinstance(self.strategy, TrainingTypePlugin):
+            self._training_type_plugin = self.strategy
+
     def handle_given_plugins(self) -> None:

-        training_type = None
+        for plug in self.plugins:
+            if self.strategy is not None and self._is_plugin_training_type(plug):
+                raise MisconfigurationException(
+                    f"You have passed `Trainer(strategy={self.strategy})`"
+                    f" and you can only specify one training type plugin, but you have passed {plug} as a plugin."
+                )
+            if self._is_plugin_training_type(plug):
+                rank_zero_deprecation(
+                    f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated"
+                    f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead."
+                )
+
+        training_type = self._training_type_plugin or None
+
         checkpoint = None
         precision = None
         cluster_environment = None
```

Review discussion on lines +308 to +311 (the deprecation warning for passing a strategy via `accelerator`):

- "I thought we weren't going to deprecate the previous …"
- "I think accelerator flag is still there.. it's just that passing one of the strategies to it is deprecated."
- "What I understood from our offline discussion was that support …"
- "Flags …"
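A hedged sketch of how these guard paths are expected to behave from the user's side; the flag values are illustrative, and the exact warning/exception text is the one in the diff above:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Passing a strategy name through the old `accelerator` flag still works,
# but now emits a deprecation warning pointing at `strategy=`.
Trainer(accelerator="ddp_spawn", num_processes=2)

# Combining the new flag with one of the legacy routes is rejected outright.
try:
    Trainer(strategy="ddp_spawn", accelerator="ddp_spawn", num_processes=2)
except MisconfigurationException as err:
    print(err)  # hints to use just `Trainer(strategy="ddp_spawn")`
```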
```diff
@@ -350,6 +401,10 @@ def handle_given_plugins(self) -> None:
         self._checkpoint_io = checkpoint
         self._cluster_environment = cluster_environment or self.select_cluster_environment()

+    @property
+    def accelerator_types(self) -> List[str]:
+        return ["auto"] + list(DeviceType)
+
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
```
```diff
@@ -540,9 +595,18 @@ def root_gpu(self) -> Optional[int]:
             else None
         )

+    @staticmethod
+    def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool:
+        if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)):
+            return True
+        return isinstance(plugin, TrainingTypePlugin)
+
     @property
     def is_training_type_in_plugins(self) -> bool:
-        return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+        return any(
+            (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin)
+            for plug in self.plugins
+        )

     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
```
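To illustrate what the new helper treats as a "training type": registry names and `DistributedType` aliases qualify as strings, and any `TrainingTypePlugin` instance qualifies by type. A hedged sketch of the expected outcomes, calling the private static method directly for illustration only and assuming the usual plugin imports:

```python
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector

# Strings that name a training type (DistributedType aliases or registry entries)...
assert AcceleratorConnector._is_plugin_training_type("ddp")

# ...and concrete TrainingTypePlugin instances are both recognised.
assert AcceleratorConnector._is_plugin_training_type(DDPPlugin())

# Arbitrary strings that name neither a registry entry nor a DistributedType are not.
assert not AcceleratorConnector._is_plugin_training_type("not_a_training_type")
```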
```diff
@@ -875,6 +939,25 @@ def update_device_type_if_ipu_plugin(self) -> None:
         if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
             self._device_type = DeviceType.IPU

+    def update_device_type_if_training_type_plugin_passed(self) -> None:
+        if isinstance(self.strategy, TrainingTypePlugin) or any(
+            isinstance(plug, TrainingTypePlugin) for plug in self.plugins
+        ):
+            if self._accelerator_type is not None:
+                if self.use_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.use_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.use_gpu:
+                    self._device_type = DeviceType.GPU
+            else:
+                if self.has_ipu:
+                    self._device_type = DeviceType.IPU
+                elif self.has_tpu:
+                    self._device_type = DeviceType.TPU
+                elif self.has_gpu:
+                    self._device_type = DeviceType.GPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
```
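A hedged sketch of the behaviour this method is meant to produce: when the strategy is given as a plugin instance, no strategy string is available to infer devices from, so the device type is derived from the requested (or detected) hardware. The flags below are illustrative and assume a machine with at least two visible GPUs:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Strategy passed as a plugin instance carries no device information itself,
# so the connector derives the device type from the accelerator/devices flags
# (here it should end up as DeviceType.GPU).
trainer = Trainer(strategy=DDPPlugin(), accelerator="gpu", devices=2)
```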
Review comment: "Missing whitespace here."