Skip to content

Commit

Permalink
[Cherry-pick][OpenCl]Fix elementwise opencl bug (#7363)
Browse files Browse the repository at this point in the history
* [OpenCL] Add elementwise common broadcast and compute test=develop

* fix elementwise bug test=develop
  • Loading branch information
sprouteer authored Oct 22, 2021
1 parent 770c3e1 commit 32c1b3e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,19 @@ __kernel void broadcast_elementwise_common(
in0 = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_x, SAMPLER, (int2)(cur_index.x * 4, cur_index.y));

if (cur_index.x * 4 + 1 < input_nhwc4.w * 4) {
if (cur_index.x * 4 + 1 < bias_width) {
in1 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
(int2)(cur_index.x * 4 + 1, cur_index.y));
}
if (cur_index.x * 4 + 2 < input_nhwc4.w * 4) {
if (cur_index.x * 4 + 2 < bias_width) {
in2 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
(int2)(cur_index.x * 4 + 2, cur_index.y));
}
if (cur_index.x * 4 + 3 < input_nhwc4.w * 4) {
if (cur_index.x * 4 + 3 < bias_width) {
in3 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
Expand All @@ -100,19 +100,19 @@ __kernel void broadcast_elementwise_common(
in0 = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_x, SAMPLER, (int2)(cur_index.y, cur_index.x * 4));

if (cur_index.x * 4 + 1 < input_nhwc4.z * input_nhwc4.w * 4) {
if (cur_index.x * 4 + 1 < bias_width) {
in1 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
(int2)(cur_index.y, cur_index.x * 4 + 1));
}
if (cur_index.x * 4 + 2 < input_nhwc4.z * input_nhwc4.w * 4) {
if (cur_index.x * 4 + 2 < bias_width) {
in2 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
(int2)(cur_index.y, cur_index.x * 4 + 2));
}
if (cur_index.x * 4 + 3 < input_nhwc4.z * input_nhwc4.w * 4) {
if (cur_index.x * 4 + 3 < bias_width) {
in3 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
Expand Down Expand Up @@ -140,19 +140,19 @@ __kernel void broadcast_elementwise_common(
SAMPLER,
(int2)(tmp_c4 * input_nhwc4.y + tmp_w, tmp_h));

if (cur_index.x + 1 < input_nhwc4.x * input_nhwc4.y) {
if (tmp_h + 1 < bias_width) {
in1 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
(int2)(tmp_c4 * input_nhwc4.y + tmp_w, tmp_h + 1));
}
if (cur_index.x + 2 < input_nhwc4.x * input_nhwc4.y) {
if (tmp_h + 2 < bias_width) {
in2 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
(int2)(tmp_c4 * input_nhwc4.y + tmp_w, tmp_h + 2));
}
if (cur_index.x + 3 < input_nhwc4.x * input_nhwc4.y) {
if (tmp_h + 3 < bias_width) {
in3 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_x,
SAMPLER,
Expand Down Expand Up @@ -235,19 +235,19 @@ __kernel void broadcast_elementwise_common(
in0 = READ_IMG_TYPE(
CL_DTYPE_CHAR, input_y, SAMPLER, (int2)(cur_index.y, cur_index.x * 4));

if (cur_index.x * 4 + 1 < bias_nhwc4.z * bias_nhwc4.w * 4) {
if (cur_index.x * 4 + 1 < bias_width) {
in1 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_y,
SAMPLER,
(int2)(cur_index.y, cur_index.x * 4 + 1));
}
if (cur_index.x * 4 + 2 < bias_nhwc4.z * bias_nhwc4.w * 4) {
if (cur_index.x * 4 + 2 < bias_width) {
in2 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_y,
SAMPLER,
(int2)(cur_index.y, cur_index.x * 4 + 2));
}
if (cur_index.x * 4 + 3 < bias_nhwc4.z * bias_nhwc4.w * 4) {
if (cur_index.x * 4 + 3 < bias_width) {
in3 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_y,
SAMPLER,
Expand Down Expand Up @@ -275,19 +275,19 @@ __kernel void broadcast_elementwise_common(
SAMPLER,
(int2)(tmp_c4 * bias_nhwc4.y + tmp_w, tmp_h));

if (cur_index.x + 1 < bias_nhwc4.x * bias_nhwc4.y) {
if (tmp_h + 1 < bias_width) {
in1 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_y,
SAMPLER,
(int2)(tmp_c4 * bias_nhwc4.y + tmp_w, tmp_h + 1));
}
if (cur_index.x + 2 < bias_nhwc4.x * bias_nhwc4.y) {
if (tmp_h + 2 < bias_width) {
in2 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_y,
SAMPLER,
(int2)(tmp_c4 * bias_nhwc4.y + tmp_w, tmp_h + 2));
}
if (cur_index.x + 3 < bias_nhwc4.x * bias_nhwc4.y) {
if (tmp_h + 3 < bias_width) {
in3 = READ_IMG_TYPE(CL_DTYPE_CHAR,
input_y,
SAMPLER,
Expand Down
2 changes: 1 addition & 1 deletion lite/kernels/opencl/elementwise_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class ElementwiseImageCompute : public KernelLite<TARGET(kOpenCL),

int inputx_broadcast_c_flag = (x_nchw_[1] == 1) ? 1 : 0;
int inputy_broadcast_c_flag = (y_nchw_[1] == 1) ? 1 : 0;
int bias_width = y_nchw_[1];
int bias_width = out_nchw_[1];

if (y_dims_ == x_dims_) {
cl_int status = kernel_.setArg(0, *x_img);
Expand Down
4 changes: 4 additions & 0 deletions lite/kernels/opencl/elementwise_image_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,10 @@ void test_elementwise_all_dim_data_gpu() {
c = randint(1, 40);
h = randint(1, 40);
w = randint(1, 40);
n = 2;
c = 3;
h = 4;
w = 5;
std::vector<bool> xy_swap_flags{false, true};
for (auto xy_swap_flag : xy_swap_flags) {
RunElementwiseCommonSize<float>({n, c, h, w},
Expand Down

0 comments on commit 32c1b3e

Please sign in to comment.