Move rank-specific metric state logic to strategies #12193

Closed
ananthsub opened this issue Mar 2, 2022 · 9 comments · Fixed by #16661
Labels
bug (Something isn't working) · checkpointing (Related to checkpointing) · distributed (Generic distributed-related topic)

Comments

@ananthsub
Contributor

ananthsub commented Mar 2, 2022

Proposed refactor

Move this Metric state saving/loading logic to data-parallel based strategies
https://github.com/PyTorchLightning/pytorch-lightning/blob/6309a59c3cf93e0bfc352efb7cbf6c50b4544372/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L406-L422
https://github.com/PyTorchLightning/pytorch-lightning/blob/6309a59c3cf93e0bfc352efb7cbf6c50b4544372/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L179-L183
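
For context, the logic behind those two links roughly does the following (a paraphrase for illustration, not the exact source):

# on saving (dump path, paraphrased): sync metric states only when fault tolerance is enabled
if _fault_tolerant_training():
    for metric in [m for m in lightning_module.modules() if isinstance(m, Metric)]:
        metric.persistent(True)
        metric.sync()

# on loading (restore path, paraphrased): non-zero ranks reset their metric states
if not trainer.is_global_zero:
    for module in lightning_module.modules():
        if isinstance(module, Metric):
            module.reset()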

Motivation

Currently, the Trainer hardcodes how metric states are gathered and assumes checkpointing happens from rank 0 only.

  1. This is a bug in cases where distributed checkpointing is required (e.g. DeepSpeed stage 3). For distributed checkpointing, we save/load the local states per rank, so we don't need to sync metric states before saving.

  2. This isn't required for single-device training: we're not doing distributed training, so there's no need to call metric.sync() at all and the sync is pure overhead. In short, the gathering of metric states is hardcoded to assume the checkpoint is saved from rank 0 only, which does not work when distributed checkpointing is required.

The Strategy is already the place where the LightningModule state dict save/load is customized.

It is also the component that determines which ranks we save/load checkpoints from.
We can move the metric handling for data-parallel based approaches into a utility, so that the metric syncing code can be shared across the relevant strategies (DDP, Sharded, DeepSpeed stage 1/2, FSDP without distributed checkpointing, etc.).

Pitch

Rewrite the metric logic inside the checkpoint connector as utility functions:

from typing import Iterable, List

import torchmetrics
from torch import nn

def get_metrics_from_module(module: nn.Module) -> List[torchmetrics.Metric]:
    return [m for m in module.modules() if isinstance(m, torchmetrics.Metric)]

def sync_metrics(metrics: Iterable[torchmetrics.Metric]) -> None:
    for metric in metrics:
        # mark the state as persistent so it ends up in the state dict, then reduce it across ranks
        metric.persistent(True)
        metric.sync()

def unsync_metrics(metrics: Iterable[torchmetrics.Metric]) -> None:
    for metric in metrics:
        # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
        if metric._is_synced:
            metric.unsync()

def reset_metrics(metrics: Iterable[torchmetrics.Metric]) -> None:
    for metric in metrics:
        metric.reset()

Inside of DDPStrategy:

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self.lightning_module.load_state_dict(checkpoint["state_dict"])
        if not self.is_global_zero:
            # mirror the current checkpoint-connector behavior: non-zero ranks reset their metric states
            metrics = get_metrics_from_module(self.lightning_module)
            reset_metrics(metrics)

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state."""
        model = self.lightning_module
        # sync metric states across ranks before materializing the state dict, then unsync afterwards
        metrics = get_metrics_from_module(model) if _fault_tolerant_training() else []
        sync_metrics(metrics)
        state_dict = model.state_dict()
        unsync_metrics(metrics)
        return state_dict
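
For contrast (motivation point 2), a single-device strategy would skip the syncing entirely. A minimal sketch of what that could look like inside SingleDeviceStrategy, assuming the same utilities as above:

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state; no metric syncing is needed because there is only one rank."""
        return self.lightning_module.state_dict()

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # no rank-specific handling: the saved metric state is already the local state
        self.lightning_module.load_state_dict(checkpoint["state_dict"])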

Additional context



cc @awaelchli @rohitgr7 @akihironitta @ananthsub @ninginthecloud

@ananthsub added the distributed (Generic distributed-related topic) and checkpointing (Related to checkpointing) labels on Mar 2, 2022
@DuYicong515
Contributor

Noob questions, since I'm still exploring how the different strategies work: I agree that DDP should own this logic, that the current logic doesn't work with distributed checkpointing (where states shouldn't be synced), and that it isn't necessary for the SingleDevice strategy.

Since this is the default behaviour, I assume it's currently used by every strategy. What about the other strategies not mentioned here: dp, horovod, ddp_spawn, and the ones we plan to flatten from inheriting DDP to inheriting Parallel?

@ananthsub
Contributor Author

ananthsub commented Mar 2, 2022

  • This doesn't apply for dp as DataParallel runs in a single process
  • I don't think torchmetrics is compatible with horovod for syncing since torchmetrics is built on top of torch.distributed comms. @SkafteNicki could you confirm?
  • ddp_spawn would require the same changes as ddp - but ideally over time we can collapse DDP and DDPSpawn into the same strategy implementation

@ananthsub changed the title from "Move strategy-specific metrics checkpointing logic to strategies" to "Move rank-specific metric state logic to strategies" on Mar 2, 2022
@justusschock
Member

@ananthsub, yes, torchmetrics is built on torch.distributed. It probably doesn't support any other backend. We haven't looked into how much effort it would be, but IMO there is currently no need to replicate some of the distributed logic of PL to support different backends (which I think would be required).

I wonder, though, whether we could have something similar to the collectives API that was discussed for PL and only implement it for torch.distributed, so that others could at least plug in their backends. Do you think it is worth investigating that, @ananthsub @SkafteNicki, or is it too much effort for the little possible gain?

@ananthsub added the bug (Something isn't working) label on Mar 3, 2022
@ananthsub
Contributor Author

This is the gap I'm seeing:

  1. On saving a checkpoint, we sync the metric states only if fault tolerance is enabled. By default, fault tolerance is not enabled.

  2. On loading a checkpoint, we unconditionally reset all metric states on ranks other than rank 0. Therefore, the only rank whose metric state is actually restored is rank 0. Again, since fault tolerance is not enabled by default, this means we only save and load the local metric state from rank 0.

One proposal to fix this is to remove the check for fault tolerance here: https://github.com/PyTorchLightning/pytorch-lightning/blob/d923dff62763b2096bb2df1b1bbfc7f8540b4bbc/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L409-L411

This would roll out metric checkpointing for all users as part of v1.6 though. This is the only option I currently see. Are there others I'm missing?
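
In other words, the proposal is to drop that guard so the pre-save sync always runs; roughly (a sketch, not the actual diff):

# current (paraphrased): only sync when fault tolerance is enabled
if _fault_tolerant_training():
    metric.persistent(True)
    metric.sync()

# proposed: always sync before saving
metric.persistent(True)
metric.sync()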

@DuYicong515
Contributor

DuYicong515 commented Mar 3, 2022

One proposal to fix this is to remove the check for fault tolerance here.

I was wondering why the fault tolerance check was introduced in the first place. Is there a specific reason we didn't make this the default? We could also change the current behaviour so that if not _fault_tolerant_training(), we reset every metric state, though that loses all previous metric state.

It also seems metric state is not part of the checkpointed state dict by default. What's the reason for that? Were there previous use cases that kept it from becoming part of the state dict?

cc @tchaton since you are the original PR author #8641

@tchaton
Contributor

tchaton commented Mar 5, 2022

One proposal to fix this is to remove the check for fault tolerance here.

I was wondering why the fault tolerance check was introduced in the first place. Is there a specific reason we didn't make this the default? We could also change the current behaviour so that if not _fault_tolerant_training(), we reset every metric state, though that loses all previous metric state.

It also seems metric state is not part of the checkpointed state dict by default. What's the reason for that? Were there previous use cases that kept it from becoming part of the state dict?

cc @tchaton since you are the original PR author #8641

Fault Tolerance has a lot of implications across the codebase and isn't well tested enough yet. As checkpointing is the most critical piece of the codebase, adding the wrong elements to checkpoints would make reloading legacy checkpoints complex or impossible. The _fault_tolerant_training checks are meant to ensure the checkpoint doesn't get corrupted.

@DuYicong515
Contributor

DuYicong515 commented Mar 7, 2022

As checkpointing is the most critical piece of the codebase, adding the wrong elements to checkpoints would make reloading legacy checkpoints complex or impossible.

To fix this issue, shall we consider resetting all states when _fault_tolerant_training is disabled? If we simply remove the check, it will introduce wrong behaviour in distributed checkpointing scenarios.
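
In code, that alternative would look roughly like this on the restore path (a sketch using the hypothetical helpers from the pitch above, not an actual patch):

if _fault_tolerant_training():
    # keep today's behavior: only non-zero ranks reset their metric states
    if not self.trainer.is_global_zero:
        reset_metrics(get_metrics_from_module(self.trainer.lightning_module))
else:
    # fault tolerance disabled: reset every metric state on every rank (losing any previously saved state)
    reset_metrics(get_metrics_from_module(self.trainer.lightning_module))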

Also, metric state saving and loading isn't well tested yet across different scenarios, so it may be too risky to include it as part of the default state_dict. I think that's why @tchaton introduced the _fault_tolerant_training check in the first place.

@ananthsub thoughts?

@tchaton
Contributor

tchaton commented Mar 22, 2022

Hey @DuYicong515,

Yes, that sounds like a good plan. I believe it is the simplest for now.

Furthermore, there is already quite a solid test for metric reloading here: https://github.com/PyTorchLightning/pytorch-lightning/blob/581bf7f2f20b770004e866b23505eba216780d2f/tests/core/test_metric_result_integration.py#L374. It might be worth extending it to ensure metrics are fully reset on all ranks when Fault Tolerance isn't enabled.

Best,
T.C

@DuYicong515
Contributor

DuYicong515 commented May 5, 2022

The metric reset logic is only in restore_model, not in load_from_checkpoint. So if _fault_tolerant_training is enabled, the metric restoring logic is not correct when using load_from_checkpoint.

https://github.com/PyTorchLightning/pytorch-lightning/blob/6309a59c3cf93e0bfc352efb7cbf6c50b4544372/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L179-L183

Should we add the non-zero-rank reset logic to both? In addition, I think it's better to guard the reset logic with a _fault_tolerant_training() check so it's consistent with the metric checkpointing logic, e.g.:

if _fault_tolerant_training() and not self.trainer.is_global_zero:
    # reset metric states on non-zero ranks; rank 0 keeps the restored state
    for module in self.trainer.lightning_module.modules():
        if isinstance(module, Metric):
            module.reset()

cc @tchaton @SkafteNicki @justusschock
