This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

What is the use of "AllReduce"? #21

Open
yyk-wew opened this issue Mar 9, 2023 · 0 comments

Comments


yyk-wew commented Mar 9, 2023

Hello. Thank you for your great work!

I have some questions about the "AllReduce" class defined here.

msn/src/utils.py, lines 226 to 241 at 4388dc1:

```python
class AllReduce(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        if (
            dist.is_available()
            and dist.is_initialized()
            and (dist.get_world_size() > 1)
        ):
            x = x.contiguous() / dist.get_world_size()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        return grads
```

It is used to gather the probs when computing the me-max regularization.

msn/src/losses.py, lines 70 to 72 at 4388dc1:

```python
if me_max:
    avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
    rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
```
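For reference, the term above can be checked numerically: since `log(p**(-p)) = -p * log(p)`, the sum is the entropy of `avg_probs`, so `rloss` equals `log(K) - H(avg_probs)`. A minimal NumPy sketch (the probability values are made up for illustration):

```python
import math
import numpy as np

# Hypothetical averaged prototype probabilities over K = 4 prototypes.
avg_probs = np.array([0.1, 0.2, 0.3, 0.4])

# Same expression as in losses.py: -sum(log(p^(-p))) + log(K).
neg_entropy = -np.sum(np.log(avg_probs ** (-avg_probs)))
rloss = neg_entropy + math.log(float(len(avg_probs)))

# Equivalent form: log(K) minus the Shannon entropy of avg_probs.
# rloss is zero only for the uniform distribution, so minimizing it
# spreads the average predictions across all prototypes.
entropy = -np.sum(avg_probs * np.log(avg_probs))
assert abs(rloss - (math.log(4.0) - entropy)) < 1e-9
```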

I wonder why `dist.all_reduce(x)` is not used directly. It seems that using `AllReduce` multiplies the gradient by `world_size`.
I want to know whether I am correct and why this makes sense.
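To illustrate the scaling I mean, here is a minimal NumPy sketch that simulates the forward and backward of the class above without any actual distributed setup (the per-worker tensors are made up):

```python
import numpy as np

world_size = 4

# Hypothetical per-worker tensors (stand-ins for torch.mean(probs, dim=0)).
xs = [np.array([0.2, 0.8]) * (i + 1) for i in range(world_size)]

# Forward of AllReduce: divide locally by world_size, then sum across
# workers (what dist.all_reduce does), which yields the global mean.
out = sum(x / world_size for x in xs)
assert np.allclose(out, np.mean(xs, axis=0))

# Backward of AllReduce: identity -- each worker keeps the full upstream
# gradient dL/d(out). But since d(out)/d(x_i) = 1/world_size, the
# mathematical local gradient would be upstream / world_size, so the
# identity backward is world_size times larger than that.
upstream = np.array([1.0, 1.0])
identity_grad = upstream
true_local_grad = upstream / world_size
assert np.allclose(identity_grad, true_local_grad * world_size)
```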

Thx!
