Multigpu Feature #3769

Merged 14 commits into PaddlePaddle:develop on Dec 14, 2017

Conversation

@dzhwinter (Contributor) commented Aug 30, 2017

here is better for review.
fix #3651

@dzhwinter changed the title from "Multigpu" to "Multigpu Feature" on Aug 30, 2017

- GPU Model Parallelism

every GPU has `1/n` of the training data and only holds part of the complete model in GPU memory.

@helinwang (Contributor) commented Sep 4, 2017

For model parallelism, it looks like all of the data stays on one GPU? The model is split into n parts, but in most cases won't the input layer end up on a single GPU?

@dzhwinter (Author):

you are right. fixed.


Besides, it needs interfaces to synchronize model updates with each other, and to issue/merge the model from different GPU cards.

## Implement

Contributor:

Implement -> Implementation

@dzhwinter (Author):

done.


Operators are added to the sub-graphs. Every GPU is assigned a role of `rank0`, `rank1`, etc.

- **Broadcast**. The Broadcast operator distributes the initialized parameters to all the GPUs from the GPU that owns them, e.g. from the `rank0` GPU.

Contributor:

These two operators are part of the graph; please draw the dependency more clearly.
If the dependency is clear, the reader should be able to understand what the target of graph initialization is, and what the target of each training step is.

@dzhwinter (Author):

done.
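
For illustration, a minimal sketch of the Broadcast step described above, written directly against the NCCL C API rather than as the actual PaddlePaddle operator (the device count, parameter size, and variable names here are assumptions):

```cpp
// Sketch: broadcast the initialized parameter from the rank0 GPU to all GPUs
// with the NCCL ncclBcast primitive (one communicator per device, driven from
// a single process).
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  const int n_gpus = 4;                 // assumed number of devices
  const size_t count = 1 << 20;         // assumed parameter size (floats)
  std::vector<int> devs = {0, 1, 2, 3};

  std::vector<ncclComm_t> comms(n_gpus);
  ncclCommInitAll(comms.data(), n_gpus, devs.data());

  std::vector<float*> param(n_gpus);
  std::vector<cudaStream_t> streams(n_gpus);
  for (int i = 0; i < n_gpus; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(&param[i], count * sizeof(float));
    cudaStreamCreate(&streams[i]);
    // The rank0 buffer is assumed to already hold the initialized values.
  }

  // One ncclBcast call per device inside a group; root = 0 means the rank0
  // GPU owns the source copy.
  ncclGroupStart();
  for (int i = 0; i < n_gpus; ++i) {
    ncclBcast(param[i], count, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();

  for (int i = 0; i < n_gpus; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
  }
  return 0;
}
```

In the design above this call sequence is wrapped as a Broadcast operator inserted into each GPU's sub-graph, so the user never calls NCCL directly.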


These two operators need the Multi-GPU context support.

Need to notice that Allreduce operator force GPUs synchronized at that point. Every device only need runs sub-graph in a loop style forever, the whole training process in synchronizing or synchronize style depends on the Allreduce point in the graph.

Contributor:

Don't quite understand what "synchronizing or synchronize style" means :)

@dzhwinter (Author):

typo fixed.


Note that the Allreduce operator forces the GPUs to synchronize at that point. Every device only needs to run its sub-graph in a loop forever; whether the whole training process is asynchronous or synchronous depends on where the Allreduce point sits in the graph.

For the simplest implementation, when each GPU computes the gradient of `W`, it is followed by an `AllReduce` operator that accumulates `dW` over the full batch of data; then each GPU runs the optimization step individually and applies the gradient to its own `W`.

Contributor:

Move the "Implementation" section here?

@dzhwinter (Author):

I don't think so. The graph converter is also part of our implementation. To stay consistent with dist_train.md, we put it in an independent paragraph to make the document clearer.
I changed the first sentence to avoid ambiguity.

Contributor:

Got it.
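
To make the quoted AllReduce step concrete, below is a rough sketch of one data-parallel iteration: each card computes a local `dW` on its `1/n` of the batch, `ncclAllReduce` sums `dW` so every card holds the full-batch gradient, and then every card applies the update to its own copy of `W`. This uses raw NCCL/cuBLAS calls, not the actual operators; the device ids `0..n-1`, the SGD update, and all names are assumptions:

```cpp
// Sketch of one synchronous data-parallel training step across n GPUs.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

void data_parallel_step(const std::vector<ncclComm_t>& comms,
                        const std::vector<cudaStream_t>& streams,
                        const std::vector<cublasHandle_t>& blas,
                        const std::vector<float*>& W,   // per-GPU parameter copy
                        const std::vector<float*>& dW,  // per-GPU local gradient
                        size_t count, float lr) {
  const int n_gpus = static_cast<int>(comms.size());

  // 1. Sum the local gradients in place: afterwards every dW[i] holds the
  //    gradient of the full batch. This is the synchronization point that the
  //    AllReduce operator introduces.
  ncclGroupStart();
  for (int i = 0; i < n_gpus; ++i) {
    ncclAllReduce(dW[i], dW[i], count, ncclFloat, ncclSum, comms[i], streams[i]);
  }
  ncclGroupEnd();

  // 2. Every GPU runs the optimizer individually: W[i] += (-lr) * dW[i].
  const float alpha = -lr;
  for (int i = 0; i < n_gpus; ++i) {
    cudaSetDevice(i);
    cublasSetStream(blas[i], streams[i]);
    cublasSaxpy(blas[i], static_cast<int>(count), &alpha, dW[i], 1, W[i], 1);
  }

  for (int i = 0; i < n_gpus; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
  }
}
```

Every card ends up doing the same optimizer work here, which is exactly the redundancy the per-parameter root-card variant discussed later in this thread tries to remove.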


*Broadcast and AllReduce on a single machine; Broadcast, AllReduce, Send, and Recv across multiple machines*

<img src="images/multigpu_before_convert.png" width="300"/>

Contributor:

Picture missing?

@dzhwinter (Author):

That's weird. Fixed.


2. Control operators between GPUs will be inserted into the graph.

*Broadcast and AllReduce on a single machine; Broadcast, AllReduce, Send, and Recv across multiple machines*

Contributor:

Since you mentioned "Send, Recv", can you please add a reference link to these design docs?

@dzhwinter (Author):

Thanks for the reminder! Done.

### Benefits

- The optimize sub-graph can easily be moved to a parameter server, so the multi-GPU feature remains compatible with the distributed training design.
- It is easy to plug in the NCCL2 library.

Contributor:

Reference the NCCL library URL, please.

@dzhwinter (Author):

Done.

@helinwang (Contributor) commented Sep 13, 2017

  • We need a "To Be Decided" section describing "Explicit between send / recv vs. implicit copy on use".

  • At the beginning of training, the framework needs to issue the same sub-graph to every GPU in Data Parallelism, or different sub-graph in Model Parallelism.

    "same sub-graph" is not necessarily true here. Maybe change to "At the beginning of training, The framework will issue a sub-graph to every GPU"

  • and issue/merge model from

    what does "issue" mean?

  • These two operators need the Multi-GPU context support.

    Do we want to allow an OP to stay on different devices?

  • Every device only need runs sub-graph in a loop style forever

    This depends on what the user's eval target is: if it is a while OP, it will loop forever; otherwise it will not.

  • In fact, in the way of every GPU optimized full batch of data, wasted (n-1) GPU compute resources. We will enhance it in the next stage.

    Do we need to enhance it? I think it wastes computation, but saves one round of communication.

@typhoonzero (Contributor):

Same questions as @helinwang:

We mentioned both GPU data parallelism and model parallelism, and it seems that we are going to implement GPU data parallelism first. Should we point this out?

@dzhwinter (Author):

It should be an NCCL-based design doc only. Thank you for the review, guys!

@dzhwinter (Author):

The confusing part about data parallelism vs. model parallelism has been removed, and the Allreduce section details have been added.


As shown in the picture, when each GPU computes the gradient of `W`, it is followed by an `AllReduce` operator that accumulates `dW` over the full batch of data; then each GPU runs the optimization step individually and applies the gradient to its own `W`.

- **AllReduce2**

Contributor:

I think we need to decide on one all-reduce OP. Supporting two different OPs for the same purpose is just too much labor.

I am leaning more towards implementing our own AllReduce, since AllReduce2 adds one more dependency, NCCL2, and NCCL2 is closed source.

@dzhwinter (Author):

AllReduce2 is a composed operator written by hand. We only use the Reduce operator to implement AllReduce2.

We have already switched to NCCL2 in Paddle, so it is not an additional dependency.

Contributor:

I see, thanks.

Since there is already AllReduce, do we need another AllReduce2? For the reasons mentioned above.

@dzhwinter (Author):

Yeah, we only need AllReduce2, actually. I wrote down AllReduce2 just to avoid people confusing it with the NCCL built-in AllReduce.

Should I remove the AllReduce description and keep only AllReduce2?

@helinwang (Contributor) commented Dec 12, 2017

What do you think about calling it AllReduce? It's a PaddlePaddle OP, and there is no AllReduce1, so we probably should not name it AllReduce2.

@dzhwinter (Author):

Done.

- **AllReduce2**
If we use the NCCL2 AllReduce primitive, every GPU runs the optimizer on the full batch's gradients, wasting (n-1) GPUs' worth of compute. In addition, AllReduce only uses the communication resources during synchronization, and the gradient update becomes a separate phase. In fact, we can amortize the gradient-update cost into the communication phase.
- Every parameter has its root card. That card calls the **Reduce** operator to collect the gradients from the GPUs.
- The whole model's parameters are hashed to different root cards to ensure load balance between the GPUs.

Contributor:

Just a personal question: should we add a device_id field to Var in the protobuf, or would NCCL handle this by itself?

@dzhwinter (Author):

It's still a controversial topic in our design, and it's not determined by NCCL, so we can leave that discussion to the multi-device parallelism topic.


- **AllReduce**
Note that our AllReduce operator is a ring-based AllReduce implementation. If we use the NCCL2 AllReduce primitive, every GPU runs the optimizer on the full batch's gradients, wasting (n-1) GPUs' worth of compute. In addition, the NCCL2 built-in AllReduce only uses the communication resources during synchronization, and the gradient update becomes a subsequent phase. In fact, we can amortize the gradient-update cost into the communication phase. The process is:
1. Every parameter has its root card. That card is responsible for aggregating the gradients from the GPUs.

Contributor:

Maybe we could describe how the parameters are distributed (round-robin, hash, or user-specified)?

@dzhwinter (Author):

No, that's another problem, coupled with parallel.do; @tonyyang-svail is working on it.
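
As an illustration of the root-card scheme described in this thread, here is a rough sketch (raw NCCL again, with a hypothetical hash-based assignment rather than PaddlePaddle's real placement logic): each parameter's gradients are reduced onto its root GPU, the root applies the update, and the updated value is broadcast back to the other cards.

```cpp
// Sketch of the per-parameter root-card scheme: hash each parameter to a root
// GPU, ncclReduce its gradients onto that GPU, update there, then ncclBcast
// the new value back. Types, names, and the hash choice are assumptions.
#include <cuda_runtime.h>
#include <nccl.h>
#include <functional>
#include <string>
#include <vector>

struct ParamShard {
  std::string name;
  std::vector<float*> W;   // per-GPU copies of the parameter
  std::vector<float*> dW;  // per-GPU local gradients
  size_t count;
};

// Hash each parameter name to a root card so the update work is spread
// (load-balanced) across the GPUs.
int root_card(const std::string& name, int n_gpus) {
  return static_cast<int>(std::hash<std::string>()(name) % n_gpus);
}

void reduce_update_broadcast(const std::vector<ncclComm_t>& comms,
                             const std::vector<cudaStream_t>& streams,
                             std::vector<ParamShard>& params) {
  const int n_gpus = static_cast<int>(comms.size());
  for (auto& p : params) {
    const int root = root_card(p.name, n_gpus);

    // 1. Collect (sum) the gradients on the root card only.
    ncclGroupStart();
    for (int i = 0; i < n_gpus; ++i) {
      ncclReduce(p.dW[i], p.dW[root], p.count, ncclFloat, ncclSum, root,
                 comms[i], streams[i]);
    }
    ncclGroupEnd();

    // 2. Only the root card runs the optimizer, e.g. W[root] -= lr * dW[root]
    //    on streams[root] (the update kernel is omitted in this sketch).

    // 3. Broadcast the updated parameter from the root back to every GPU.
    ncclGroupStart();
    for (int i = 0; i < n_gpus; ++i) {
      ncclBcast(p.W[i], p.count, ncclFloat, root, comms[i], streams[i]);
    }
    ncclGroupEnd();
  }
}
```

Compared with the plain AllReduce version, the optimizer runs once per parameter instead of once per GPU, and the update work is interleaved with the communication for the other parameters.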


## Motivation

NCCL is an NVIDIA library that supports multi-GPU communication and is optimized for NVIDIA GPUs. It provides routines such as all-gather, all-reduce, broadcast, reduce, and reduce-scatter, which can achieve high bandwidth over PCIe and NVLink high-speed interconnects. [NCCL](https://developer.nvidia.com/nccl). With the NCCL library, we can easily accelerate training in parallel.

Contributor:

... PCIe and NVLink high-speed interconnects. NCCL. With the NCCL library, we can easily accelerate training in parallel.

Maybe move the link to the front of the sentence?

NCCL is an NVIDIA library that supports multi-GPU communication and is optimized for NVIDIA GPUs.

@dzhwinter (Author):

Done.


### Graph Converter

To be compatible with the [parameter server design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/ops/dist_train.md), the graph converter converts the user-defined operation graph into sub-graphs to be executed on different devices.

Member:

graph converter => transpiler

@dzhwinter (Author):

Done.
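
As a very rough illustration of what such a transpiler pass could do, the sketch below clones the user's ops onto every device, keeps parameter initialization only on `rank0` followed by a Broadcast, and appends an AllReduce after every gradient op. The `OpDesc`/`GraphDesc` structs are hypothetical stand-ins, not PaddlePaddle's real ProgramDesc/transpiler API:

```cpp
// Hypothetical transpiler sketch: convert a single-device op list into n
// per-GPU sub-graphs with Broadcast/AllReduce control operators inserted.
#include <string>
#include <vector>

struct OpDesc {
  std::string type;                 // e.g. "fc", "fc_grad", "sgd", "broadcast"
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  bool is_param_init;               // fills a parameter, e.g. "uniform_random"
  bool is_grad;                     // produces a gradient such as "W@GRAD"
};

struct GraphDesc {
  int device_id;
  std::vector<OpDesc> ops;
};

std::vector<GraphDesc> transpile(const std::vector<OpDesc>& graph, int n_gpus) {
  std::vector<GraphDesc> sub_graphs(n_gpus);
  for (int dev = 0; dev < n_gpus; ++dev) {
    sub_graphs[dev].device_id = dev;
    for (const OpDesc& op : graph) {
      if (op.is_param_init) {
        // Only rank0 really initializes; every device then takes part in the
        // collective Broadcast that distributes the value from rank0.
        if (dev == 0) sub_graphs[dev].ops.push_back(op);
        sub_graphs[dev].ops.push_back(
            {"broadcast", op.outputs, op.outputs, false, false});
        continue;
      }
      sub_graphs[dev].ops.push_back(op);
      if (op.is_grad) {
        // Synchronize the gradient across devices right after it is produced,
        // so each card optimizes with the full-batch gradient.
        sub_graphs[dev].ops.push_back(
            {"allreduce", op.outputs, op.outputs, false, false});
      }
    }
  }
  return sub_graphs;
}
```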

As shown in the picture, when each GPU computes the gradient of `W`, it is followed by an `AllReduce` operator that accumulates `dW` over the full batch of data; then each GPU runs the optimization step individually and applies the gradient to its own `W`.

- **AllReduce**
Note that our AllReduce operator is a ring-based AllReduce implementation. If we use the NCCL2 AllReduce primitive, every GPU runs the optimizer on the full batch's gradients, wasting (n-1) GPUs' worth of compute. In addition, the NCCL2 built-in AllReduce only uses the communication resources during synchronization, and the gradient update becomes a subsequent phase. In fact, we can amortize the gradient-update cost into the communication phase. The process is:

@jacquesqiao (Member) commented Dec 14, 2017

NCCL2 also supports ring-based AllReduce; see https://github.com/PaddlePaddle/Paddle/wiki/NCCL2-Survey

@dzhwinter (Author):

It's not the same thing. What we need is not just a ring-based AllReduce: NCCL2's AllReduce only supports simple operations like sum and max, and we need to do our own optimization inside it.

@Yancey1989 (Contributor) left a comment

LGTM, and maybe @helinwang would like to review this PR again.

@helinwang merged commit c52a0bd into PaddlePaddle:develop on Dec 14, 2017