Multigpu Feature #3769
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Design Doc: Multi-GPU support in Operation Graph | ||
|
||
## Abstract | ||
|
||
This design doc describes the multi-GPU feature in Paddle. We propose an approach that supports multi-GPU training both on a single machine and across multiple machines. Every device runs only the sub-graphs that the framework issues to it. We use the `Broadcast` and `AllReduce` operators to join the per-device sub-graphs into the whole graph. | ||
|
||
|
||
|
||
## Motivation | ||
|
||
Paddle supports training with multiple CPUs and GPUs, where each one refers to a different physical device. To accelerate training, we need to support running on multiple GPUs in parallel. There are two aspects: | ||
|
||
- GPU Data Parallelism | ||
|
||
Suppose we have `n` GPUs. Every GPU holds `1/n` of the training data and stores a complete copy of the model in GPU memory. | ||
|
||
- GPU Model Parallelism | ||
|
||
Every GPU holds one part of the complete model in GPU memory. | ||
|
||
At the beginning of training, the framework needs to issue the same sub-graph to every GPU in Data Parallelism, or a different sub-graph to each GPU in Model Parallelism. | ||
|
||
During training, we need operations for peer-to-peer copy between GPUs, for aggregating gradients/parameters from GPUs, and for broadcasting parameters to GPUs. Every GPU only needs to run its sub-graph with the correct place information. | ||
|
||
In addition, the framework needs interfaces for the GPUs to synchronize model updates with each other, and to issue and merge models across different GPU cards. | ||
|
||
## Implementation | ||
|
||
As mentioned above, several kinds of operators are needed. Currently, we need one operator that issues parameters to the different GPUs, which we name Broadcast, and one that synchronizes parameters between GPUs, which we name AllReduce. | ||
|
||
### Graph Converter | ||
|
||
To be compatible with the parameter server design doc, the graph converter converts the user-defined operation graph into sub-graphs to be executed on different devices. | ||
|
||
1. The user-defined operator graph will be partitioned into sub-graphs. | ||
|
||
2. Control operators between GPUs will be inserted into the graph. | ||
|
||
*On a single machine: Broadcast and AllReduce. Across multiple machines: Broadcast, AllReduce, Send, and Recv.* | ||
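
The two converter steps above can be sketched for the single-machine, data-parallel case as follows. This is only an illustration under stated assumptions: the `Op` class, the `convert` function, the place strings such as `gpu:0`, and the op types `mul`, `mul_grad`, and `sgd` are hypothetical names, not Paddle's actual API.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """Hypothetical operator node: a type, its variables, and a device place."""
    type: str
    inputs: list
    outputs: list
    place: str = ""


def convert(graph, places, params, grads):
    """Replicate `graph` once per place and insert control operators.

    Assumption: data parallelism, so every place receives the same sub-graph.
    """
    sub_graphs = {}
    for place in places:
        # Step 2a: every replica starts by receiving the initialized parameters
        # (the owner is assumed to be the first place, i.e. rank0).
        sub = [Op("Broadcast", [p], [p], place) for p in params]
        for op in graph:
            # Step 1: copy the user-defined op and pin it to this place.
            sub.append(Op(op.type, list(op.inputs), list(op.outputs), place))
            # Step 2b: synchronize each gradient right after it is produced.
            for out in op.outputs:
                if out in grads:
                    sub.append(Op("AllReduce", [out], [out], place))
        sub_graphs[place] = sub
    return sub_graphs


user_graph = [
    Op("mul", ["X", "W"], ["Y"]),
    Op("mul_grad", ["X", "W", "dY"], ["dW"]),
    Op("sgd", ["W", "dW"], ["W"]),
]
sub_graphs = convert(user_graph, ["gpu:0", "gpu:1"], params={"W"}, grads={"dW"})
for place, ops in sub_graphs.items():
    print(place, [op.type for op in ops])
```

In this sketch the `AllReduce` lands between the gradient op and the optimizer op, so each replica optimizes with the gradient accumulated over all GPUs.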
|
||
<img src="images/multigpu_before_convert.png" width="300"/> | ||
Picture missing?
It's so weird. Fixed. |
||
|
||
After conversion, the graph looks as follows: | ||
|
||
<img src="images/multigpu_allreduce.png" width="1000"/> | ||
|
||
Operators are added to the sub-graphs. Every GPU is assigned a role such as `rank0`, `rank1`, etc. | ||
|
||
- **Broadcast**. The Broadcast operator distributes the initialized parameters from the GPU that owns them (e.g. the `rank0` GPU) to all the other GPUs. | ||
These two operators are part of the graph, please draw the dependency more clearly.
Done. |
||
- **AllReduce**. The AllReduce operator synchronizes parameters/gradients between GPUs. AllReduce is implemented with a ring-based communication method, which avoids a bottleneck on any single GPU. | ||
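
To make the ring-based method concrete, here is a small single-process simulation of a ring all-reduce (sum). It is a sketch, not the operator's implementation: the function name `ring_allreduce` and the use of Python lists to stand in for per-GPU buffers are assumptions made for illustration. Each rank only ever sends one chunk to its right-hand neighbour per step, which is why no single GPU becomes a bottleneck.

```python
def ring_allreduce(buffers):
    """Simulate a ring all-reduce (sum) over equally sized per-rank buffers."""
    n = len(buffers)
    length = len(buffers[0])
    # Split each buffer into n contiguous chunks (some may be empty).
    bounds = [(i * length) // n for i in range(n + 1)]
    data = [list(b) for b in buffers]

    def chunk(i):
        i %= n
        return range(bounds[i], bounds[i + 1])

    def exchange(chunk_of_rank, combine):
        # Collect all messages first so the sends of one step are simultaneous.
        msgs = []
        for r in range(n):
            idx = chunk(chunk_of_rank(r))
            msgs.append(((r + 1) % n, idx, [data[r][k] for k in idx]))
        for dst, idx, vals in msgs:
            for k, v in zip(idx, vals):
                data[dst][k] = combine(data[dst][k], v)

    # Phase 1, reduce-scatter: after n-1 steps, rank r holds the full sum
    # of chunk (r + 1) mod n.
    for step in range(n - 1):
        exchange(lambda r: r - step, lambda old, new: old + new)

    # Phase 2, all-gather: circulate the completed chunks around the ring.
    for step in range(n - 1):
        exchange(lambda r: r + 1 - step, lambda old, new: new)

    return data


grads = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
result = ring_allreduce(grads)
print(result[0])
```

Every rank ends up with the same element-wise sum, and each rank transfers roughly `2 * (n - 1) / n` of its buffer in total, independent of which rank it is.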
|
||
Both operators require multi-GPU context support. | ||
|
||
Note that the AllReduce operator forces the GPUs to synchronize at that point. Every device only needs to run its sub-graph in an endless loop; whether the whole training process is asynchronous or synchronous depends on where the AllReduce points are placed in the graph. | ||
|
||
In the simplest implementation, each GPU computes the gradient of `W`, and an `AllReduce` operator follows to accumulate `dW` over the full batch of data. Each GPU then runs the optimization process individually and applies the gradient to its own copy of `W`. | ||
Move the "Implementation" section here?
I don't think so. The graph converter is also part of our implementation. To stay consistent with dist_train.md, we put it in a separate section to make the document clearer.
Got it. |
||
|
||
In this scheme every GPU runs the same optimization step on the full-batch gradient, so the optimization work on (n-1) GPUs is redundant. We will improve this in the next stage. | ||
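
The simplest scheme described above can be sketched in plain Python for a one-parameter model `y = w * x` with a squared-error loss. The helper names (`local_grad`, `allreduce_sum`, `train_step`) and the toy data are hypothetical and exist only for illustration; a real `AllReduce` would communicate between devices rather than sum a Python list.

```python
def local_grad(w, shard):
    """Gradient of sum(0.5 * (w * x - y) ** 2) over one GPU's data shard."""
    return sum((w * x - y) * x for x, y in shard)


def allreduce_sum(values):
    """Stand-in for AllReduce: every rank receives the sum over all ranks."""
    total = sum(values)
    return [total] * len(values)


def train_step(weights, shards, lr):
    """One synchronous step; weights[i] is the copy of `w` held by rank i."""
    grads = [local_grad(w, s) for w, s in zip(weights, shards)]
    grads = allreduce_sum(grads)  # dW accumulated over the full batch
    batch_size = sum(len(s) for s in shards)
    # Every rank applies the identical update to its own copy of w.
    return [w - lr * g / batch_size for w, g in zip(weights, grads)]


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
shards = [data[:2], data[2:]]  # n = 2 GPUs, each holding 1/n of the batch
replicas = train_step([0.0, 0.0], shards, lr=0.1)
single = train_step([0.0], [data], lr=0.1)
print(replicas, single)
```

Both replicas end the step with the same `w`, which also matches a single device processing the whole batch. The final update line is what every GPU repeats identically, which is the redundancy noted above.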
|
||
### Benefits | ||
|
||
- The optimization sub-graph can easily be moved to a parameter server, so the multi-GPU feature stays compatible with the distributed training design. | ||
- The NCCL2 library can be plugged in easily. | ||
Reference the NCCL library URL, please.
Done. |
||
- GPU model parallelism becomes easier to implement: we only need to replace each GPU's sub-graph with a different part of the whole graph. |
Since you mentioned "Send, Recv" can you please add a reference link to these design docs?
Thanks for the reminder! Done.