Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix roi_pool op bug #10700

Merged
merged 8 commits into from
May 18, 2018
Merged
40 changes: 23 additions & 17 deletions paddle/fluid/operators/roi_pool_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ __global__ void GPUROIPoolForward(
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;

const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = roi_batch_id_data[n];
Expand All @@ -52,14 +52,19 @@ __global__ void GPUROIPoolForward(

int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

int hstart = static_cast<int>(floor(static_cast<double>(ph) *
static_cast<double>(roi_height) /
static_cast<double>(pooled_height)));
int wstart = static_cast<int>(floor(static_cast<double>(pw) *
static_cast<double>(roi_width) /
static_cast<double>(pooled_width)));
int hend = static_cast<int>(ceil(static_cast<double>(ph + 1) *
static_cast<double>(roi_height) /
static_cast<double>(pooled_height)));
int wend = static_cast<int>(ceil(static_cast<double>(pw + 1) *
static_cast<double>(roi_width) /
static_cast<double>(pooled_width)));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
Expand All @@ -79,9 +84,9 @@ __global__ void GPUROIPoolForward(
}
}
}
output_data[index] = maxval;
output_data[i] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
argmax_data[i] = maxidx;
}
}
}
Expand All @@ -96,10 +101,10 @@ __global__ void GPUROIPoolBackward(
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int pw = i % pooled_width;
int ph = (i / pooled_width) % pooled_height;
int c = (i / pooled_width / pooled_height) % channels;
int n = i / pooled_width / pooled_height / channels;

int roi_batch_ind = roi_batch_id_data[n];
int input_offset = (roi_batch_ind * channels + c) * height * width;
Expand Down Expand Up @@ -138,6 +143,7 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];

int rois_num = rois->dims()[0];

if (rois_num == 0) return;

int output_size = out->numel();
Expand Down