
bp of shared parameters and experts #161

Open
a157801 opened this issue Jun 14, 2022 · 7 comments
Labels
question Further information is requested

Comments

a157801 commented Jun 14, 2022

PyTorch DDP cannot distinguish expert parameters from other shared parameters, so the experts may be updated with the shared (all-reduced) gradient.
TutelDistributedOptimizer seems to be an implementation of ZeRO, which by itself does not change the gradients. How does Tutel deal with this problem?

ghostplant added the question (Further information is requested) label on Jun 14, 2022
ghostplant (Contributor) commented:

Yes, TutelDistributedOptimizer is a replacement for PyTorch DDP in that example (helloworld_ddp_tutel), making the whole model synchronization transparent.

TutelDistributedOptimizer not only implements ZeRO optimization, it also leverages a built-in mask (_tutel_expert) to distinguish whether a parameter is shared or was created by tutel.moe.moe_layer.

Note that TutelDistributedOptimizer treats only parameters created by tutel.moe.moe_layer as expert parameters. If the model never uses tutel.moe.moe_layer, there is no difference from PyTorch DDP (except that TutelDistributedOptimizer includes the ZeRO feature).
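
To make the mask idea concrete, here is a minimal sketch (not Tutel's actual internals) of how a per-parameter flag such as _tutel_expert can be used to partition parameters so that only the shared group takes part in cross-rank gradient averaging. The attribute name comes from the discussion above; the manual all-reduce is purely illustrative.

```python
import torch
import torch.distributed as dist

def partition_parameters(model):
    # Split parameters into (shared, expert) groups using a per-parameter mask.
    # The `_tutel_expert` attribute name is taken from the discussion above and
    # is assumed to be set on expert parameters created by tutel.moe.moe_layer.
    shared, expert = [], []
    for p in model.parameters():
        (expert if getattr(p, '_tutel_expert', False) else shared).append(p)
    return shared, expert

def allreduce_shared_grads(shared_params, world_size):
    # Average gradients of the shared parameters only; expert gradients stay
    # local to each rank (each rank holds different experts).
    for p in shared_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
```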

a157801 (Author) commented Jun 14, 2022

Thank you for your answer.
I notice that the _tutel_expert flag is used to split the parameters. But it seems that the gradients of experts carrying _tutel_expert will still be all-reduced by DDP. The _tutel_expert flag indicates that these parameters are experts and will not be partitioned across GPUs, but it does not control the all-reduce operation.

ghostplant (Contributor) commented Jun 15, 2022

To use TutelDistributedOptimizer, which includes parameter synchronization, you should no longer wrap the model with DistributedDataParallel.
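
As a rough illustration of why double-wrapping is a problem, the sketch below shows an optimizer wrapper that performs the shared-parameter synchronization itself, in the spirit of (but not identical to) TutelDistributedOptimizer. If the model were additionally wrapped in DistributedDataParallel, DDP would also all-reduce the expert gradients, which are supposed to stay local.

```python
import torch
import torch.distributed as dist

class SyncedOptimizer:
    # Illustration only, not TutelDistributedOptimizer: this wrapper averages the
    # shared-parameter gradients itself before stepping, so the model must NOT
    # also be wrapped in DistributedDataParallel.
    def __init__(self, base_optimizer, shared_params, world_size):
        self.base = base_optimizer
        self.shared = shared_params
        self.world_size = world_size

    def step(self):
        # Synchronize only the shared parameters across ranks; expert gradients
        # remain local to each rank.
        for p in self.shared:
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad.div_(self.world_size)
        self.base.step()

    def zero_grad(self, set_to_none=True):
        self.base.zero_grad(set_to_none=set_to_none)
```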

a157801 (Author) commented Jun 17, 2022

I notice that the code in the Swin-Transformer repo (https://github.com/microsoft/Swin-Transformer/blob/main/main_moe.py) uses a PyTorch optimizer and DDP to train these MoE models. Maybe there is something wrong there. Thanks a lot.

ghostplant (Contributor) commented:

It is a version that manually distinguishes parameter types, following helloworld_ddp.py.

a157801 (Author) commented Jun 17, 2022

Does it work by setting skip_allreduce to true in the scan function?

ghostplant (Contributor) commented Jun 17, 2022

To use Tutel MoE with the PyTorch DDP backend, you need to not only set skip_allreduce to true in the MoE scan function, but also recollect the parameters carrying those masks and tell DDP to skip synchronizing them, as in https://github.com/microsoft/tutel/blob/main/tutel/examples/helloworld_ddp.py#L92. Otherwise, PyTorch DDP won't know they are expert parameters, and they will be synchronized unexpectedly.
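
Here is a hedged sketch of that DDP-based route, loosely following the linked helloworld_ddp.py: collect the parameters that the MoE layer flagged to skip all-reduce and tell DDP to ignore them before wrapping. The `skip_allreduce` attribute check and the private `_ddp_params_and_buffers_to_ignore` attribute are assumptions here; verify both against your Tutel and PyTorch versions and the example itself.

```python
import torch

def wrap_with_ddp_skipping_experts(model, local_rank):
    # Sketch: names of parameters that the MoE scan flagged to skip all-reduce.
    # The `skip_allreduce` attribute follows the discussion above; adjust it to
    # however your Tutel version actually marks expert parameters.
    skip_names = [name for name, p in model.named_parameters()
                  if getattr(p, 'skip_allreduce', False)]

    # DDP reads this (private) attribute from the module it wraps and excludes
    # the listed parameters/buffers from gradient synchronization, so expert
    # gradients stay local while shared parameters are still synchronized.
    model._ddp_params_and_buffers_to_ignore = skip_names

    return torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```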
