Add EvaluationDistributedSampler and examples on distributed evaluation #1886
Conversation
Codecov Report
Additional details and impacted files

@@           Coverage Diff           @@
##           master   #1886    +/-  ##
======================================
- Coverage      87%     87%     -0%
======================================
  Files         270     270
  Lines       15581   15592    +11
======================================
+ Hits        13483   13488     +5
- Misses       2098    2104     +6
can we please add tests for validation and training as well? And maybe an fsdp test? Also some notes on caveats might be good to add to the sampler docs
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)

len_dataset = len(self.dataset)  # type: ignore[arg-type]
if not self.drop_last and len_dataset % self.num_replicas != 0:
The issue with this is that it wouldn't necessarily work with validation, since not all ranks would reach the same distributed function calls and would therefore time out, which would kill the entire process. Also, this would never work with FSDP, since some ranks have one batch more and, for FSDP, not all processes would reach the forward syncing points, also resulting in timeouts.
I agree that in the context of Lightning this wouldn't work well, as it does not support Join (Lightning-AI/pytorch-lightning#3325)
FSDP also doesn't support join afaik (pytorch/pytorch#64683)
But outside Lightning, and taking FSDP out of the equation, I agree this can work and is a good utility to have IMO. It also suits the metric design well, since synchronization is only necessary when all processes have finished collecting their statistics and .compute() can be called.
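A minimal sketch of the update/compute pattern described above (plain torchmetrics usage with dummy data; the per-rank batches are placeholders): each rank only accumulates local statistics in update(), and the single cross-rank synchronization happens in compute(), which every rank reaches even when the per-rank batch counts differ.

import torch
from torchmetrics.classification import BinaryAccuracy

# each rank would iterate over its own shard; a dummy two-batch shard stands in here
local_batches = [
    (torch.rand(8), torch.randint(2, (8,))),
    (torch.rand(8), torch.randint(2, (8,))),
]

metric = BinaryAccuracy()
for preds, target in local_batches:
    metric.update(preds, target)  # purely local, no communication

acc = metric.compute()  # in DDP, the cross-rank all_gather/reduce happens only here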
calling @awaelchli for distributed review :)
You are right that we need to test this feature better to clearly state the limitations.
Converted to draft until better tested.
"""

def __init__(
In Lightning we have a very similar class: https://github.com/Lightning-AI/lightning/blob/fbdbe632c67b05158804b52f4345944781ca4f07/src/lightning/pytorch/overrides/distributed.py#L194
I think the main difference is that yours respects the drop_last setting. I'm not sure why we have __iter__ overridden there, but if you are interested you can compare the two.
        metric_class=metric_class,
    ),
    range(NUM_PROCESSES),
)
In addition, a unit test for just the sampler alone could be useful, one that doesn't launch processes (not needed) but rather just asserts that the indices returned on each rank match the expectation, e.g.:

sampler = EvaluationDistributedSampler(dataset, rank=0, num_replicas=3, drop_last=...)
assert list(iter(sampler)) == ...

sampler = EvaluationDistributedSampler(dataset, rank=2, num_replicas=3, drop_last=...)
assert list(iter(sampler)) == ...
and so on to test all edge cases.
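For concreteness, a hedged sketch of such a standalone test (the import path and the exact index-splitting strategy of EvaluationDistributedSampler are assumptions). Rather than hard-coding per-rank indices, it asserts the property the sampler is meant to guarantee: all ranks together cover every sample exactly once, with no padding for an uneven dataset size.

import torch
from torch.utils.data import TensorDataset

# import path is an assumption; the class is added in this PR
from torchmetrics.utilities.distributed import EvaluationDistributedSampler


def test_no_extra_samples_for_uneven_dataset():
    dataset = TensorDataset(torch.arange(10))  # 10 samples, not divisible by 3 replicas
    all_indices = []
    for rank in range(3):
        sampler = EvaluationDistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
        all_indices.extend(iter(sampler))
    # every sample appears exactly once across ranks, i.e. no duplicated/padded indices
    assert sorted(all_indices) == list(range(10))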
Co-authored-by: Adrian Wälchli <[email protected]>
@SkafteNicki, what is missing here to make it land? 🐿️
@SkafteNicki this seems to be pending for a while, shall we continue? 🐰
Closing: while this would be nice to have, it is probably too limited in scope to provide enough value, and it is not a core feature of torchmetrics.
What does this PR do?
Fixes #1338
The original issue asks whether we should implement a join context so that metrics can be evaluated on an uneven number of samples in distributed settings. As a reminder, we normally discourage users from evaluating in distributed mode because the default distributed sampler from PyTorch adds extra samples to make all processes do even work, which messes with the results.
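To make the padding behaviour concrete, here is a small illustration in plain PyTorch (no torchmetrics involved): with 10 samples split across 3 replicas, the default DistributedSampler pads to 12 indices, so two samples are evaluated twice.

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

indices = []
for rank in range(3):
    # passing num_replicas/rank explicitly avoids needing an initialized process group
    sampler = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
    indices.extend(iter(sampler))

print(len(indices))     # 12, two more than the dataset contains
print(sorted(indices))  # [0, 0, 1, 1, 2, 3, ..., 9], so samples 0 and 1 are counted twice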
After investigating this issue, it seems that we do not need a join context at all due to the custom synchronization we have for metrics. To understand this we need to look at the two different states we can have: tensor state and list of tensor states.
For the list-of-tensors case, assume the state on rank 0 is a list of two tensors [t_01, t_02] and the state on rank 1 is a list of one tensor [t_11] (rank 0 has seen one more batch than rank 1). When list states are encountered internally, we make sure to concatenate the list into a single tensor so we do not need to call allgather for each tensor in the list (see torchmetrics/src/torchmetrics/metric.py, lines 418 to 419 at 879595d). After this, each state is a single tensor, t_0 and t_1, but clearly t_0.shape != t_1.shape. Again, we handle this internally by padding the tensors to the same size and then doing an all gather (see torchmetrics/src/torchmetrics/utilities/distributed.py, lines 136 to 148 at 879595d).
Thus in both cases, even if one rank sees more samples/batches, we still do the same number of distributed operations per rank, which should mean that everything works.
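A simplified sketch of the pad-then-gather pattern that the linked utility implements (illustrative only, not the actual torchmetrics code): the key point is that every rank performs the same two collective calls, no matter how many samples it has seen.

import torch
import torch.distributed as dist


def gather_uneven_tensors(local_tensor: torch.Tensor) -> list:
    """Gather tensors whose first dimension differs across ranks."""
    world_size = dist.get_world_size()

    # 1) exchange local sizes so every rank knows how much padding to trim later
    local_size = torch.tensor([local_tensor.shape[0]], device=local_tensor.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)
    max_size = int(torch.stack(sizes).max())

    # 2) pad the local tensor to the maximum size found on any rank
    pad = max_size - local_tensor.shape[0]
    padded = torch.cat([local_tensor, local_tensor.new_zeros(pad, *local_tensor.shape[1:])])

    # 3) all_gather the now equally sized tensors (same number of collectives on every rank)
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # 4) trim the padding away using the exchanged sizes
    return [t[: int(s)] for t, s in zip(gathered, sizes)]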
To highlight this feature of TM, this PR does a couple of things:
- Adds a new EvaluationDistributedSampler that does not add extra samples. Users can use it as a drop-in replacement for any DistributedSampler if they want to do proper distributed evaluation (otherwise they just need to make sure that the number of samples is evenly divisible by the number of processes); a usage sketch follows below.
- Adds examples on distributed evaluation.
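A hypothetical usage sketch (the import path and exact constructor signature are assumptions, modeled on torch's DistributedSampler): the sampler drops into the evaluation DataLoader exactly where DistributedSampler would normally go.

import torch.distributed as dist
from torch.utils.data import DataLoader

# import path assumed; the class is added in this PR
from torchmetrics.utilities.distributed import EvaluationDistributedSampler

sampler = EvaluationDistributedSampler(
    val_dataset,                       # placeholder for your evaluation dataset
    num_replicas=dist.get_world_size(),
    rank=dist.get_rank(),
    shuffle=False,                     # keep the original order for evaluation
)
val_loader = DataLoader(val_dataset, batch_size=32, sampler=sampler)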
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
📚 Documentation preview 📚: https://torchmetrics--1886.org.readthedocs.build/en/1886/