
Refactor Trainer device properties that use Trainer._accelerator_connector #12171

Closed
DuYicong515 opened this issue Mar 1, 2022 · 1 comment


DuYicong515 commented Mar 1, 2022

Proposed refactor

Final state:
Of the device-related properties, we will keep device_ids, num_devices, and num_nodes.

The properties below will be deprecated in favor of deriving them directly from the above three.
https://github.com/PyTorchLightning/pytorch-lightning/blob/d4d197070fc2c6c04d460bbfb8b1b9d3a2ebc944/pytorch_lightning/trainer/trainer.py#L2029-L2057

https://github.com/PyTorchLightning/pytorch-lightning/blob/d4d197070fc2c6c04d460bbfb8b1b9d3a2ebc944/pytorch_lightning/trainer/trainer.py#L2118-L2120

Motivation

1/ There are a bunch of device-related properties on Trainer that retrieve their values from Trainer._accelerator_connector. However, those values should be retrievable from Trainer.strategy or Trainer.accelerator instead. AcceleratorConnector is internal-facing, and its properties are not meant to be public.

2/ Some of the properties can easily be derived from the others, so we would also love to deprecate them in favor of deriving them from the existing ones. These include num_processes, root_gpu, tpu_cores, ipus, num_gpus, and data_parallel_device_ids.

3/ Trainer.devices currently returns num_devices, which is confusing given its name. We would also love to deprecate it in favor of Trainer.device_ids and Trainer.num_devices.

Related discussions in #12126, #11624

Pitch

Kept properties

@property
def device_ids(self) -> List[int]:
    devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
    device_ids = []
    for idx, device in enumerate(devices):
        if isinstance(device, torch.device):
            device_ids.append(device.index or idx)
        elif isinstance(device, int):
            device_ids.append(device)
    return device_ids

@property
def num_devices(self) -> int:
    return len(self.device_ids)

@property
def num_nodes(self) -> int:
    return getattr(self.strategy, "num_nodes", 1)
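
For illustration, a hypothetical usage sketch of the kept properties (the exact values depend on the available hardware and the Trainer flags):

trainer = Trainer(accelerator="gpu", devices=2, num_nodes=1)
trainer.device_ids   # [0, 1]
trainer.num_devices  # 2
trainer.num_nodes    # 1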

The other properties will be deprecated, and their implementations changed to derive from the properties above. Examples:

@property
def devices(self) -> Optional[Union[List[int], str, int]]:
    rank_zero_deprecation(
        "`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
    )
    return self.num_devices
    
@property
def data_parallel_device_ids(self) -> List[int]:
    rank_zero_deprecation(
        "`Trainer.data_parallel_device_ids` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `self.device_ids if isinstance(self.accelerator, GPUAccelerator) else []` instead."
    )
    return self.device_ids if isinstance(self.accelerator, GPUAccelerator) else []
        
@property
def ipus(self) -> int:
    rank_zero_deprecation(
        "`Trainer.ipus` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0` instead."
    )
    return self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
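
Along the same lines, a sketch of what num_gpus could look like after the change (the exact message wording would follow the pattern above):

@property
def num_gpus(self) -> int:
    rank_zero_deprecation(
        "`Trainer.num_gpus` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `self.num_devices if isinstance(self.accelerator, GPUAccelerator) else 0` instead."
    )
    return self.num_devices if isinstance(self.accelerator, GPUAccelerator) else 0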

Implementation steps

1/ Introduce the new properties device_ids and num_devices.
2/ Deprecate the others: change their implementations to derive directly from the existing ones, and add deprecation messages.
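
For step 2, each deprecation would also come with a test asserting that the warning fires, roughly along these lines (test name and match string are illustrative):

import pytest
from pytorch_lightning import Trainer

def test_trainer_num_gpus_deprecation():
    trainer = Trainer()
    with pytest.deprecated_call(match="`Trainer.num_gpus` was deprecated in v1.6"):
        _ = trainer.num_gpus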



cc @justusschock @awaelchli @rohitgr7 @kaushikb11 @Borda @ananthsub @ninginthecloud @jjenniferdai @akihironitta

DuYicong515 (Contributor, Author) commented:

cc @four4fish, @carmocca, @awaelchli, @ananthsub

Feel free to share your thoughts on which ones to deprecate and which ones to keep.
