Optimize the PyTorch CUDA implementation for Criss Cross Attention #1088

Closed
wants to merge 4 commits into from
190 changes: 89 additions & 101 deletions mmcv/ops/csrc/cc_attention_cuda_kernel.cuh
@@ -14,24 +14,22 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int z = blockIdx.z;

if (x < width && y < height && z < height + width - 1) {
for (int batch = 0; batch < num; ++batch) {
for (int plane = 0; plane < chn; ++plane) {
T _t = t[(batch * chn + plane) * sp + y * width + x];

if (z < width) {
int i = z;
T _f = f[(batch * chn + plane) * sp + y * width + i];
weight[(batch * len + i) * sp + y * width + x] += _t * _f;
} else {
int i = z - width;
int j = i < y ? i : i + 1;

T _f = f[(batch * chn + plane) * sp + j * width + x];
weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
}
int z = blockIdx.z % len;
int batch = blockIdx.z / len;

if (x < width && y < height) {
for (int plane = 0; plane < chn; ++plane) {
T _t = t[(batch * chn + plane) * sp + y*width + x];

if (z < width) {
int i = z;
T _f = f[(batch * chn + plane) * sp + y*width + i];
weight[(batch * len + i) * sp + y*width + x] += _t*_f;
} else {
int i = z - width;
int j = i<y ? i : i+1;
T _f = f[(batch * chn + plane) * sp + j*width + x];
weight[(batch * len + width + i) * sp + y*width + x] += _t*_f;
Comment on lines +22 to +32

Member: Please add space between variables (y, width ...) and ops (*, + ...)

Contributor Author: OK, I'll fix it.

}
}
}
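
For reference, a minimal sketch of the flagged block in ca_forward_kernel with the spacing requested above applied (no functional change intended):

      if (z < width) {
        int i = z;
        T _f = f[(batch * chn + plane) * sp + y * width + i];
        weight[(batch * len + i) * sp + y * width + x] += _t * _f;
      } else {
        int i = z - width;
        int j = i < y ? i : i + 1;
        T _f = f[(batch * chn + plane) * sp + j * width + x];
        weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
      }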
@@ -44,23 +42,22 @@ __global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + i) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + y * width + i];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i < y ? i : i - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;

if (x < width && y < height ) {
for (int i = 0; i < width; ++i) {
float _dw = dw[(batch * len + i) * sp + y*width + x];
float _f = f[(batch * chn + plane) * sp + y*width + i];
dt[(batch * chn + plane) * sp + y*width + x] += _dw * _f;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i<y ? i : i-1;

T _dw = dw[(batch * len + width + j) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + i * width + x];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
float _dw = dw[(batch * len + width + j) * sp + y*width + x];
float _f = f[(batch * chn + plane) * sp + i*width + x];
Comment on lines +50 to +59

Member: Any reason for using float instead of T here?

Contributor Author: Sorry, I'll replace float with T.

dt[(batch * chn + plane) * sp + y*width + x] += _dw * _f;
}
}
}
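
A sketch of the two accumulation loops in ca_backward_kernel_t with the template parameter T restored in place of float, as agreed in the comment above:

      for (int i = 0; i < width; ++i) {
        T _dw = dw[(batch * len + i) * sp + y * width + x];
        T _f = f[(batch * chn + plane) * sp + y * width + i];
        dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
      }
      for (int i = 0; i < height; ++i) {
        if (i == y) continue;
        int j = i < y ? i : i - 1;
        T _dw = dw[(batch * len + width + j) * sp + y * width + x];
        T _f = f[(batch * chn + plane) * sp + i * width + x];
        dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
      }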
@@ -72,23 +69,22 @@ __global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + x) * sp + y * width + i];
T _t = t[(batch * chn + plane) * sp + y * width + i];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;

if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + x) * sp + y*width + i];
T _t = t[(batch * chn + plane) * sp + y*width + i];
df[(batch * chn + plane) * sp + y*width + x] += _dw * _t;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i>y ? y : y-1;

T _dw = dw[(batch * len + width + j) * sp + i * width + x];
T _t = t[(batch * chn + plane) * sp + i * width + x];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
T _dw = dw[(batch * len + width + j) * sp + i*width + x];
T _t = t[(batch * chn + plane) * sp + i*width + x];
df[(batch * chn + plane) * sp + y*width + x] += _dw * _t;
}
}
}
@@ -100,24 +96,23 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _g = g[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + i) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
T res = 0;
Member: Has this res been used?

if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _g = g[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + i) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;

int j = i < y ? i : i - 1;
int j = i < y ? i : i - 1;

T _g = g[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
T _g = g[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
}
}
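
On the question about res above: if it was intended as a register accumulator, a sketch of how ca_map_forward_kernel could use it — summing in a register and writing to global memory once — assuming that was the intent (otherwise the variable can simply be removed):

    template <typename T>
    __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
                                          int num, int chn, int height, int width) {
      const int x = blockIdx.x * blockDim.x + threadIdx.x;
      const int y = blockIdx.y * blockDim.y + threadIdx.y;
      const int sp = height * width;
      const int len = height + width - 1;
      const int plane = blockIdx.z % chn;
      const int batch = blockIdx.z / chn;

      if (x < width && y < height) {
        T res = 0;  // accumulate in a register, write out once at the end
        for (int i = 0; i < width; ++i) {
          res += g[(batch * chn + plane) * sp + y * width + i] *
                 weight[(batch * len + i) * sp + y * width + x];
        }
        for (int i = 0; i < height; ++i) {
          if (i == y) continue;
          const int j = i < y ? i : i - 1;
          res += g[(batch * chn + plane) * sp + i * width + x] *
                 weight[(batch * len + width + j) * sp + y * width + x];
        }
        out[(batch * chn + plane) * sp + y * width + x] += res;
      }
    }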
@@ -130,25 +125,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int z = blockIdx.z;

if (x < width && y < height && z < height + width - 1) {
for (int batch = 0; batch < num; ++batch) {
for (int plane = 0; plane < chn; ++plane) {
T _dout = dout[(batch * chn + plane) * sp + y * width + x];

if (z < width) {
int i = z;
T _g = g[(batch * chn + plane) * sp + y * width + i];
dw[(batch * len + i) * sp + y * width + x] += _dout * _g;
} else {
int i = z - width;
int j = i < y ? i : i + 1;

T _g = g[(batch * chn + plane) * sp + j * width + x];
dw[(batch * len + width + i) * sp + y * width + x] += _dout * _g;
}
}
int z = blockIdx.z % len;
int batch = blockIdx.z / len;

if (x < width && y < height) {
int widx = (batch * len + z) * sp + y*width + x;
int dout_idx = batch * chn * sp + y * width + x;
int gidx = batch * chn * sp;
if (z < width) {
gidx += y * width + z;
} else {
int j = z - width;
j = j < y ? j : j + 1;
gidx += j * width + x;
}
for(int plane = 0; plane < chn; plane ++){
dw[widx] += dout[dout_idx + plane * sp] * g[gidx+plane*sp];
Comment on lines +133 to +144

Member: This part looks great! Can we do the very same in ca_forward_kernel? By the way, it would be better to use const int for these variables, which are never changed.

Contributor Author: OK, thanks for the good advice!

}
}
}
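
Following the suggestion above, a sketch of how ca_forward_kernel could adopt the same precomputed-index style with const int — an illustration of the proposed follow-up, not the code merged in this PR:

    template <typename T>
    __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
                                      int chn, int height, int width) {
      const int x = blockIdx.x * blockDim.x + threadIdx.x;
      const int y = blockIdx.y * blockDim.y + threadIdx.y;
      const int sp = height * width;
      const int len = height + width - 1;
      const int z = blockIdx.z % len;
      const int batch = blockIdx.z / len;

      if (x < width && y < height) {
        // Precompute the flat indices once; only the plane offset varies in the loop.
        const int widx = (batch * len + z) * sp + y * width + x;
        const int tidx = batch * chn * sp + y * width + x;
        int fidx = batch * chn * sp;
        if (z < width) {
          fidx += y * width + z;
        } else {
          const int i = z - width;
          const int j = i < y ? i : i + 1;
          fidx += j * width + x;
        }
        for (int plane = 0; plane < chn; ++plane) {
          weight[widx] += t[tidx + plane * sp] * f[fidx + plane * sp];
        }
      }
    }

A single widx covers both branches because width + i equals z when z >= width.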
@@ -161,25 +154,20 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z;

if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) {
T _dout = dout[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + x) * sp + y * width + i];
dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
int index = (batch * chn + plane) * sp + y*width + x;

T _dout = dout[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + i * width + x];
dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
}
if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
dg[index] += dout[(batch * chn + plane) * sp + y*width + i] * weight[(batch * len + x) * sp + y*width + i];
}
int j = 0;
for (int i = 0; i < height; ++i) {
if (i == y) continue;
j = i > y ? y : y - 1;
Comment on lines +165 to +168

Member: Use const int j = i > y ? y : y - 1; inside the for loop.

dg[index] += dout[(batch * chn + plane) * sp + i * width + x] * weight[(batch * len + width + j) * sp + i * width + x];
}
}
}

#endif // CC_ATTENTION_CUDA_KERNEL_CUH
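
Applying the const int comment above, the second loop of ca_map_backward_kernel_g might read as follows (same logic, with j scoped to the loop body):

    for (int i = 0; i < height; ++i) {
      if (i == y) continue;
      const int j = i > y ? y : y - 1;
      dg[index] += dout[(batch * chn + plane) * sp + i * width + x] *
                   weight[(batch * len + width + j) * sp + i * width + x];
    }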
15 changes: 8 additions & 7 deletions mmcv/ops/csrc/pytorch/cc_attention_cuda.cu
@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w;
dim3 blocks(d1, d2, d3);
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);

AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c;
int d3 = c * n;
dim3 blocks(d1, d2, d3);

AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c;
int d3 = c * n;
dim3 blocks(d1, d2, d3);

AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w;
dim3 blocks(d1, d2, d3);
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);

AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
g.contiguous().data_ptr<scalar_t>(),
dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});

d3 = c * n;
blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(),
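
To summarize the launcher-side change in this file: the batch dimension is folded into gridDim.z and decoded inside each kernel, so the per-batch loop inside the kernels disappears. A minimal standalone sketch of the pattern, with a hypothetical kernel and launcher (not the mmcv code):

    #include <cuda_runtime.h>

    __global__ void example_kernel(float *data, int len, int height, int width) {
      const int x = blockIdx.x * blockDim.x + threadIdx.x;
      const int y = blockIdx.y * blockDim.y + threadIdx.y;
      const int z = blockIdx.z % len;      // position along the criss-cross axis
      const int batch = blockIdx.z / len;  // batch index recovered from gridDim.z
      if (x < width && y < height) {
        data[((batch * len + z) * height + y) * width + x] += 1.0f;
      }
    }

    void launch_example(float *data, int num, int height, int width,
                        cudaStream_t stream) {
      const dim3 threads(32, 32);
      const int d1 = (width + threads.x - 1) / threads.x;
      const int d2 = (height + threads.y - 1) / threads.y;
      const int len = height + width - 1;  // was h + w before this PR
      // Fold the batch dimension into the z dimension of the grid.
      // Note: gridDim.z is capped at 65535, so len * num must stay below that.
      const dim3 blocks(d1, d2, len * num);
      example_kernel<<<blocks, threads, 0, stream>>>(data, len, height, width);
    }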