Optimize the rowwise add function. #7047
template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int64_t height,
                                 int64_t width) {
  int64_t num = height * width;
`num` can be passed in as a parameter.
Done. Thank you!
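To make the suggestion concrete, here is a minimal host-side sketch in plain C++ (not the PR's actual kernel). It assumes the row-wise add computes `c[h][w] = a[h][w] + b[w]` on a row-major `height x width` input, which the diff above does not show in full. `RowwiseAddReference` and `RowwiseAdd` are hypothetical names used only for illustration. The point is that the caller computes `num` once and passes it in, so the per-thread body never recomputes `height * width`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side reference for a row-wise add, assuming c[h][w] = a[h][w] + b[w]
// on a row-major layout. `num` (= height * width) is supplied by the caller
// instead of being recomputed inside the function.
template <typename T>
void RowwiseAddReference(const T* a, const T* b, T* c, int64_t width,
                         int64_t num) {
  assert(width > 0 && num % width == 0);
  for (int64_t i = 0; i < num; ++i) {
    c[i] = a[i] + b[i % width];
  }
}

// Hypothetical caller: computes `num` once and forwards it.
template <typename T>
std::vector<T> RowwiseAdd(const std::vector<T>& a, const std::vector<T>& b,
                          int64_t height) {
  const int64_t width = static_cast<int64_t>(b.size());
  const int64_t num = height * width;
  assert(static_cast<int64_t>(a.size()) == num);
  std::vector<T> c(a.size());
  RowwiseAddReference(a.data(), b.data(), c.data(), width, num);
  return c;
}
```

In the CUDA version the same idea applies to the launch site: the host computes `num` before the launch and every thread reads it as a kernel argument.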
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    int h = i / width;
    int w = i % width;
Integer division and modulo (%) are expensive on GPU hardware.
It seems the division can be replaced by a multiplication, and the modulo (%) by a multiplication and a subtraction.
Done. Thank you!
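A rough host-side sketch of the index arithmetic discussed above, written in plain C++ so it can be checked without a GPU. The helper names (`SplitDivMod`, `SplitDivMulSub`, `SplitReciprocal`) and the `RowCol` struct are hypothetical and do not come from the PR. The reciprocal variant is one possible reading of "replace the division by a multiplication". It assumes non-negative indices that are exactly representable in a `double`, and it adds a correction step because a rounded reciprocal can land one row off. A single-precision reciprocal would have less headroom for large indices.

```cpp
#include <cassert>
#include <cstdint>

struct RowCol {
  int64_t h;  // row index
  int64_t w;  // column index
};

// Baseline: one integer division plus one modulo per element.
inline RowCol SplitDivMod(int64_t i, int64_t width) {
  assert(width > 0);
  return {i / width, i % width};
}

// One division only: the remainder is recovered as i - h * width.
inline RowCol SplitDivMulSub(int64_t i, int64_t width) {
  assert(width > 0);
  const int64_t h = i / width;
  return {h, i - h * width};
}

// No integer division: estimate the row with a precomputed reciprocal
// (inv_width = 1.0 / width), then correct a possible off-by-one caused by
// floating-point rounding. Assumes i >= 0 and small enough that the
// estimate is off by at most one row.
inline RowCol SplitReciprocal(int64_t i, int64_t width, double inv_width) {
  assert(i >= 0 && width > 0);
  int64_t h = static_cast<int64_t>(static_cast<double>(i) * inv_width);
  if (h * width > i) {
    --h;  // estimate was one row too high
  } else if ((h + 1) * width <= i) {
    ++h;  // estimate was one row too low
  }
  return {h, i - h * width};
}
```

Whether the reciprocal form is actually faster depends on the target GPU and the compiler, so it is worth profiling both variants before committing to one.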
01bc012 to 5c94725
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int64_t width,
                                 int64_t num) {
  T tmp = 1.0 / width;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
It would be better to change the type of `num` to `int`. Otherwise, the loop condition compares an `int` with an `int64_t`.
Done. Thanks!
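As an illustration of the type concern, the sketch below (plain C++, host side) keeps the loop index and the bound in the same type and checks the narrowing from `int64_t` explicitly. `FitsInInt` and `CountVisited` are hypothetical helpers, not part of the PR. `CountVisited` emulates the grid-stride loop for a single thread, with `start` standing in for `blockIdx.x * blockDim.x + threadIdx.x` and `stride` for `blockDim.x * gridDim.x`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// True when an element count can be narrowed to `int` without loss.
inline bool FitsInInt(int64_t num) {
  return num >= 0 && num <= std::numeric_limits<int>::max();
}

// Emulates one thread of a grid-stride loop with index and bound both `int`.
// Returns how many elements this "thread" would visit.
inline int64_t CountVisited(int start, int stride, int num) {
  assert(start >= 0 && stride > 0);
  int64_t visited = 0;
  for (int i = start; i < num; i += stride) {
    ++visited;
    if (num - i <= stride) break;  // next step would pass num; avoids overflow
  }
  return visited;
}
```

A launcher could call `FitsInInt(height * width)` before narrowing, and fall back to a 64-bit index (or report an error) when the check fails.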
58efe9b to 1936738
LGTM
Fix #6842
This mainly targets the broadcast in Eigen. The timings before and after the optimization are as follows:
Experiment environment:
Total time of 2 epochs: