Move rank-specific metric state logic to strategies #12193
Comments
Noob question since I'm still exploring how the different strategies work: I think I agree that DDP should own this logic; the current logic doesn't work with distributed checkpointing, where states shouldn't be synced, and it isn't necessary for the SingleDevice strategy. Since this is the default behaviour, I assumed it's currently applied to every strategy. What about the other strategies not mentioned here?
@ananthsub, yes, torchmetrics is built on torch.distributed. It probably doesn't support any other backend. We haven't looked into how much effort it would be, but IMO there is currently no need to replicate some of PL's distributed logic to support different backends (which I think would be required). I wonder, though, if we could have something similar to the collectives API that was discussed for PL and only implement it for torch.distributed, so that others could at least plug in their backends. Do you think that is worth investigating, @ananthsub @SkafteNicki, or is it too much effort for the little possible gain?
This is the gap I'm seeing:
One proposal to fix this is to remove the check for fault tolerance here: https://github.com/PyTorchLightning/pytorch-lightning/blob/d923dff62763b2096bb2df1b1bbfc7f8540b4bbc/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L409-L411 This would roll out metric checkpointing to all users as part of v1.6, though. It's the only option I currently see. Are there others I'm missing?
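For illustration only, a minimal paraphrase of what dropping that guard would mean; the function and key names below are placeholders, not the code at the linked lines:

```python
# Hypothetical paraphrase of the guard under discussion (placeholder names, not
# the actual checkpoint connector code): today metric states only reach the
# checkpoint when fault-tolerant training is enabled.
from typing import Any, Dict


def dump_metrics_into_checkpoint(
    checkpoint: Dict[str, Any], metric_states: Dict[str, Any], fault_tolerant: bool
) -> None:
    # Removing this condition is the proposal: metric states would then be
    # persisted for every user, not just those with fault tolerance enabled.
    if fault_tolerant:
        checkpoint["metrics"] = metric_states
```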
I was wondering why the fault-tolerance check was introduced in the first place. Was there a specific reason we didn't make this the default? I think we could also change the current behaviour: right now the metric state isn't part of the checkpoint state dict by default. What's the reason for that? Were there previous use cases that kept it out of the state dict?
Fault Tolerance has a lot of implications across the codebase and isn't well tested enough yet. As checkpointing is the most critical piece of the codebase, wrongly adding elements to the checkpoints would make reloading legacy checkpoints complex or impossible.
To fix this issue, shall we consider resetting all metric states when fault tolerance isn't enabled? Also, I feel that metric state saving and loading isn't well tested in different scenarios yet; maybe it's too risky to include it as part of the default state_dict. I feel that's why @tchaton introduced the fault-tolerance check. @ananthsub thoughts?
Hey @DuYicong515, yes, that sounds like a good plan. I believe it is the simplest for now. Furthermore, we should have quite a solid test for metric reloading there: https://github.com/PyTorchLightning/pytorch-lightning/blob/581bf7f2f20b770004e866b23505eba216780d2f/tests/core/test_metric_result_integration.py#L374. It might be worth extending it to ensure metrics are fully reset on all ranks when Fault Tolerance isn't enabled. Best,
The metric reset logic was only in one of the two places. Should we add the reset-on-non-zero-ranks logic to both? In addition, I think it's better to guard the reset logic so it only runs when fault tolerance isn't enabled, as in the sketch below.
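A minimal sketch of that guarded reset, where plain booleans stand in for the trainer's rank info and the fault-tolerance check (both placeholders, not the real connector state):

```python
# Minimal sketch of the reset discussed above. `is_global_zero` and
# `fault_tolerant_enabled` are placeholders for the real trainer/connector state.
from torch import nn
from torchmetrics import Metric


def maybe_reset_metric_states(module: nn.Module, is_global_zero: bool, fault_tolerant_enabled: bool) -> None:
    if fault_tolerant_enabled:
        # fault tolerance restores metric states from the checkpoint,
        # so nothing should be wiped here
        return
    if not is_global_zero:
        # the checkpointed states were accumulated on rank 0 via syncing at save
        # time; other ranks start from a clean slate to avoid double counting
        for child in module.modules():
            if isinstance(child, Metric):
                child.reset()
```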
Proposed refactor
Move this Metric state saving/loading logic to data-parallel based strategies
https://github.com/PyTorchLightning/pytorch-lightning/blob/6309a59c3cf93e0bfc352efb7cbf6c50b4544372/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L406-L422
https://github.com/PyTorchLightning/pytorch-lightning/blob/6309a59c3cf93e0bfc352efb7cbf6c50b4544372/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L179-L183
Motivation
Currently, the Trainer hardcodes how metric states are gathered and assumes checkpointing happens from rank 0 only.
This is a bug in cases where distributed checkpointing is required (e.g. DeepSpeed stage 3). With distributed checkpointing, we save/load the local states per rank, meaning we don't need to sync metric states before saving.
This isn't required for single-device training either: we don't need to call metric.sync() at all because we're not doing distributed training, so from that point of view it's pure overhead. The Strategy is already the place where the LightningModule state dict save/load is customized.
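For reference, an abridged, illustrative version of those hooks (method names as they appear on the v1.6-era Strategy base class; the bodies are simplified, so treat this as a sketch rather than the exact source):

```python
# Abridged, illustrative version of the Strategy state-dict hooks referenced
# above; simplified, not the exact source.
class Strategy:
    def lightning_module_state_dict(self):
        # default behaviour: a plain state dict of the wrapped LightningModule
        return self.lightning_module.state_dict()

    def load_model_state_dict(self, checkpoint):
        self.lightning_module.load_state_dict(checkpoint["state_dict"])
```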
It is also the component that determines which ranks we save and load checkpoints from.
We can move the metric handling for data-parallel based approaches into a utility, so that the metric-syncing code can be shared across the relevant strategies (DDP, Sharded, DeepSpeed stage 1/2, FSDP without distributed checkpointing, etc.).
Pitch
Rewrite the metric logic inside the checkpoint connector as utility functions:
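A sketch of what those utilities could look like, relying only on the public torchmetrics.Metric API; the names, signatures, and the idea of a dedicated context manager are assumptions, since the issue's original snippet isn't reproduced here:

```python
# Hypothetical utilities for rank-aware metric checkpointing. Names and
# structure are assumptions; only the torchmetrics.Metric calls are real API.
from contextlib import ExitStack, contextmanager

from torch import nn
from torchmetrics import Metric


@contextmanager
def synced_metric_states(module: nn.Module, should_sync: bool = True):
    """While active, each metric's local states are swapped for the globally
    synced states and marked persistent, so a state-dict dump captures the
    accumulated values. Local states are restored on exit."""
    with ExitStack() as stack:
        for metric in module.modules():
            if isinstance(metric, Metric):
                metric.persistent(True)
                # sync_context is a no-op when torch.distributed isn't initialized
                stack.enter_context(metric.sync_context(should_sync=should_sync))
        yield


def reset_metric_states(module: nn.Module) -> None:
    """Reset every metric; meant for non-zero ranks after restoring a checkpoint
    whose metric states were accumulated on rank 0 via syncing."""
    for metric in module.modules():
        if isinstance(metric, Metric):
            metric.reset()
```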
Inside of DDPStrategy:
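A sketch of how DDPStrategy could consume those utilities. lightning_module_state_dict and load_model_state_dict are existing Strategy hooks in the v1.6-era API, but using them as the insertion point for metric handling is part of this proposal, not current behaviour; synced_metric_states and reset_metric_states are the hypothetical helpers from the previous sketch:

```python
# Hypothetical wiring inside DDPStrategy, reusing the helpers sketched above.
from pytorch_lightning.strategies import ParallelStrategy


class DDPStrategy(ParallelStrategy):
    def lightning_module_state_dict(self):
        # sync metric states across ranks so the rank-0 checkpoint contains the
        # globally accumulated values
        with synced_metric_states(self.lightning_module):
            return super().lightning_module_state_dict()

    def load_model_state_dict(self, checkpoint):
        super().load_model_state_dict(checkpoint)
        # the saved metric states were accumulated on rank 0 via syncing at save
        # time; reset them on the other ranks to avoid double counting
        if not self.is_global_zero:
            reset_metric_states(self.lightning_module)
```

Sharded, DeepSpeed stage 1/2, and non-distributed-checkpointing FSDP could reuse the same helpers, while strategies that checkpoint per rank would simply skip the syncing.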
Additional context
cc @awaelchli @rohitgr7 @akihironitta @ananthsub @ninginthecloud