From 5c3f7713b33a63be955c79d6d547d49bfa2aa617 Mon Sep 17 00:00:00 2001
From: Sergey Zagoruyko
Date: Thu, 3 Dec 2015 15:21:34 +0100
Subject: [PATCH] ROIPooling v2

---
 CMakeLists.txt              |   2 +-
 ROIPooling.cu               | 271 ++++++++++++++++++++++++++++++++++++
 ROIPooling.lua              |  53 +++++++
 SpatialStochasticPooling.cu |  15 +-
 common.h                    |  18 +++
 ffi.lua                     |  10 ++
 init.lua                    |   1 +
 test/test_jacobian.lua      |  52 +++++++
 8 files changed, 407 insertions(+), 15 deletions(-)
 create mode 100644 ROIPooling.cu
 create mode 100644 ROIPooling.lua
 create mode 100644 common.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b4e95c2..fa4f2a3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -24,7 +24,7 @@ LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
 INCLUDE_DIRECTORIES("${Torch_INSTALL_INCLUDE}/THC")
 LINK_DIRECTORIES("${Torch_INSTALL_LIB}")
 
-SET(src-cuda SpatialStochasticPooling.cu SpatialCrossResponseNormalization.cu)
+FILE(GLOB src-cuda *.cu)
 FILE(GLOB luasrc *.lua)
 
 CUDA_ADD_LIBRARY(inn MODULE ${src-cuda})
diff --git a/ROIPooling.cu b/ROIPooling.cu
new file mode 100644
index 0000000..547c3ed
--- /dev/null
+++ b/ROIPooling.cu
@@ -0,0 +1,271 @@
+// ------------------------------------------------------------------
+// Fast R-CNN
+// Copyright (c) 2015 Microsoft
+// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
+// Written by Ross Girshick
+// ------------------------------------------------------------------
+
+// Torch port:
+// IMAGINE, Sergey Zagoruyko, Francisco Massa, 2015
+
+#include "THC.h"
+#include <algorithm>
+#include <cfloat>
+
+#include "common.h"
+
+
+using std::max;
+using std::min;
+
+
+template <typename Dtype>
+__global__ void ROIPoolForward(const int nthreads, const Dtype* bottom_data,
+    const Dtype spatial_scale, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const Dtype* bottom_rois, Dtype* top_data, int* argmax_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    bottom_rois += n * 5;
+    int roi_batch_ind = (bottom_rois[0] - 1);
+    int roi_start_w = round((bottom_rois[1] - 1) * spatial_scale);
+    int roi_start_h = round((bottom_rois[2] - 1) * spatial_scale);
+    int roi_end_w = round((bottom_rois[3] - 1) * spatial_scale);
+    int roi_end_h = round((bottom_rois[4] - 1) * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
+    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
+    Dtype bin_size_h = static_cast<Dtype>(roi_height)
+                       / static_cast<Dtype>(pooled_height);
+    Dtype bin_size_w = static_cast<Dtype>(roi_width)
+                       / static_cast<Dtype>(pooled_width);
+
+    int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)
+                                        * bin_size_h));
+    int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)
+                                        * bin_size_w));
+    int hend = static_cast<int>(ceil(static_cast<Dtype>(ph + 1)
+                                     * bin_size_h));
+    int wend = static_cast<int>(ceil(static_cast<Dtype>(pw + 1)
+                                     * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, 0), height);
+    hend = min(max(hend + roi_start_h, 0), height);
+    wstart = min(max(wstart + roi_start_w, 0), width);
+    wend = min(max(wend + roi_start_w, 0), width);
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    // Define an empty pooling region to be zero
+    Dtype maxval = is_empty ? 0 : -FLT_MAX;
+    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+    int maxidx = -1;
+    bottom_data += (roi_batch_ind * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int bottom_index = h * width + w;
+        if (bottom_data[bottom_index] > maxval) {
+          maxval = bottom_data[bottom_index];
+          maxidx = bottom_index;
+        }
+      }
+    }
+    top_data[index] = maxval;
+    argmax_data[index] = maxidx;
+  }
+}
+
+extern "C"
+void inn_ROIPooling_updateOutput(THCState *state,
+    THCudaTensor *output, THCudaTensor *indices,
+    THCudaTensor *data, THCudaTensor* rois, int W, int H, double spatial_scale)
+{
+  THAssert(THCudaTensor_nDimension(state, data) == 4);
+  THAssert(THCudaTensor_nDimension(state, rois) == 2 && rois->size[1] == 5);
+  THAssert(THCudaTensor_isContiguous(state, data));
+  THAssert(THCudaTensor_isContiguous(state, rois));
+  long num_rois = rois->size[0];
+  long nInputPlane = data->size[1];
+  THCudaTensor_resize4d(state, output, num_rois, nInputPlane, H, W);
+  THCudaTensor_resize4d(state, indices, num_rois, nInputPlane, H, W);
+
+  long count = THCudaTensor_nElement(state, output);
+
+  ROIPoolForward<float><<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
+      count,
+      THCudaTensor_data(state, data),
+      spatial_scale, nInputPlane, data->size[2], data->size[3], H, W,
+      THCudaTensor_data(state, rois),
+      THCudaTensor_data(state, output),
+      (int*)THCudaTensor_data(state, indices)
+      );
+
+  // check for errors
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in inn_ROIPooling_updateOutput: %s\n", cudaGetErrorString(err));
+    THError("aborting");
+  }
+}
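The v1 forward pass above mirrors the Fast R-CNN reference kernel: ROI coordinates arrive 1-indexed from Lua, are shifted to 0-indexed and scaled by spatial_scale, and each pooled cell max-pools over a bin whose edges are computed with floor()/ceil(), so neighbouring bins may overlap by a row or column. The following plain-Lua sketch restates that bin arithmetic for a single pooled cell; every name in it is illustrative, not part of the patch:

    -- v1 bin edges for pooled cell (ph, pw); roi.* are already 0-indexed
    -- and scaled by spatial_scale, as in the kernel above
    local function v1bin(ph, pw, roi, pooledH, pooledW, height, width)
      local binH = roi.height / pooledH
      local binW = roi.width / pooledW
      -- floor/ceil: adjacent bins may share a row or column
      local hstart = math.min(math.max(math.floor(ph * binH) + roi.startH, 0), height)
      local hend   = math.min(math.max(math.ceil((ph + 1) * binH) + roi.startH, 0), height)
      local wstart = math.min(math.max(math.floor(pw * binW) + roi.startW, 0), width)
      local wend   = math.min(math.max(math.ceil((pw + 1) * binW) + roi.startW, 0), width)
      return hstart, hend, wstart, wend  -- max-pool over [hstart,hend) x [wstart,wend)
    end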
+template <typename Dtype>
+__global__ void ROIPoolForwardV2(const int nthreads, const Dtype* bottom_data,
+    const Dtype spatial_scale, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const Dtype* bottom_rois, Dtype* top_data, int* argmax_data) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    bottom_rois += n * 5;
+    int roi_batch_ind = (bottom_rois[0] - 1);
+    int roi_start_w = round((bottom_rois[1] - 1) * spatial_scale);
+    int roi_start_h = round((bottom_rois[2] - 1) * spatial_scale);
+    int roi_end_w = round((bottom_rois[3] - 1) * spatial_scale) - 1;
+    int roi_end_h = round((bottom_rois[4] - 1) * spatial_scale) - 1;
+
+    // Force malformed ROIs to be 1x1
+    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
+    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
+    Dtype bin_size_h = static_cast<Dtype>(roi_height)
+                       / static_cast<Dtype>(pooled_height);
+    Dtype bin_size_w = static_cast<Dtype>(roi_width)
+                       / static_cast<Dtype>(pooled_width);
+
+    int hstart = static_cast<int>(round(static_cast<Dtype>(ph)
+                                        * bin_size_h));
+    int wstart = static_cast<int>(round(static_cast<Dtype>(pw)
+                                        * bin_size_w));
+    int hend = static_cast<int>(round(static_cast<Dtype>(ph + 1)
+                                      * bin_size_h));
+    int wend = static_cast<int>(round(static_cast<Dtype>(pw + 1)
+                                      * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, 0), height);
+    hend = min(max(hend + roi_start_h, 0), height);
+    wstart = min(max(wstart + roi_start_w, 0), width);
+    wend = min(max(wend + roi_start_w, 0), width);
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    // Define an empty pooling region to be zero
+    Dtype maxval = is_empty ? 0 : -FLT_MAX;
+    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+    int maxidx = -1;
+    bottom_data += (roi_batch_ind * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int bottom_index = h * width + w;
+        if (bottom_data[bottom_index] > maxval) {
+          maxval = bottom_data[bottom_index];
+          maxidx = bottom_index;
+        }
+      }
+    }
+    top_data[index] = maxval;
+    argmax_data[index] = maxidx;
+  }
+}
+
+extern "C"
+void inn_ROIPooling_updateOutputV2(THCState *state,
+    THCudaTensor *output, THCudaTensor *indices,
+    THCudaTensor *data, THCudaTensor* rois, int W, int H, double spatial_scale)
+{
+  THAssert(THCudaTensor_nDimension(state, data) == 4);
+  THAssert(THCudaTensor_nDimension(state, rois) == 2 && rois->size[1] == 5);
+  THAssert(THCudaTensor_isContiguous(state, data));
+  THAssert(THCudaTensor_isContiguous(state, rois));
+  long num_rois = rois->size[0];
+  long nInputPlane = data->size[1];
+  THCudaTensor_resize4d(state, output, num_rois, nInputPlane, H, W);
+  THCudaTensor_resize4d(state, indices, num_rois, nInputPlane, H, W);
+
+  long count = THCudaTensor_nElement(state, output);
+
+  ROIPoolForwardV2<float><<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
+      count,
+      THCudaTensor_data(state, data),
+      spatial_scale, nInputPlane, data->size[2], data->size[3], H, W,
+      THCudaTensor_data(state, rois),
+      THCudaTensor_data(state, output),
+      (int*)THCudaTensor_data(state, indices)
+      );
+
+  // check for errors
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in inn_ROIPooling_updateOutputV2: %s\n", cudaGetErrorString(err));
+    THError("aborting");
+  }
+}
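V2 departs from the reference kernel in two places: the ROI end coordinates get an extra - 1, and bin edges use round() instead of floor()/ceil(). Because the end of bin ph and the start of bin ph + 1 are rounded identically, the v2 bins tile the ROI without overlapping, so each input element feeds at most one pooled output per ROI. A small Lua comparison for a ROI of height 7 pooled into 4 bins (illustrative only; math.floor(x + 0.5) stands in for round(), which Lua 5.1 lacks):

    local roiHeight, pooled = 7, 4
    local bin = roiHeight / pooled
    for ph = 0, pooled - 1 do
      local v1s, v1e = math.floor(ph * bin), math.ceil((ph + 1) * bin)
      local v2s = math.floor(ph * bin + 0.5)
      local v2e = math.floor((ph + 1) * bin + 0.5)
      print(('bin %d: v1 [%d,%d)  v2 [%d,%d)'):format(ph, v1s, v1e, v2s, v2e))
    end
    -- v1: [0,2) [1,4) [3,6) [5,7)  -- bins 0/1 and 1/2 overlap
    -- v2: [0,2) [2,4) [4,5) [5,7)  -- disjoint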
+template <typename Dtype>
+__global__ void ROIPoolBackwardAtomic(const int nthreads, const Dtype* top_diff,
+    const int* argmax_data, const int num_rois, const Dtype spatial_scale,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, Dtype* bottom_diff,
+    const Dtype* bottom_rois) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    bottom_rois += n * 5;
+    int roi_batch_ind = (bottom_rois[0] - 1);
+    int bottom_offset = (roi_batch_ind * channels + c) * height * width;
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const Dtype* offset_top_diff = top_diff + top_offset;
+    Dtype* offset_bottom_diff = bottom_diff + bottom_offset;
+    const int* offset_argmax_data = argmax_data + top_offset;
+
+    int argmax = offset_argmax_data[ph * pooled_width + pw];
+    if (argmax != -1) {
+      atomicAdd(offset_bottom_diff + argmax,
+                offset_top_diff[ph * pooled_width + pw]);
+    }
+  }
+}
+
+extern "C"
+void inn_ROIPooling_updateGradInputAtomic(THCState *state,
+    THCudaTensor *gradInput, THCudaTensor *indices, THCudaTensor *data,
+    THCudaTensor *gradOutput, THCudaTensor* rois, int W, int H, double spatial_scale)
+{
+  THAssert(THCudaTensor_nDimension(state, data) == 4);
+  THAssert(THCudaTensor_nDimension(state, rois) == 2 && rois->size[1] == 5);
+  THAssert(THCudaTensor_isContiguous(state, data));
+  THAssert(THCudaTensor_isContiguous(state, rois));
+  long num_rois = rois->size[0];
+  long nInputPlane = data->size[1];
+  THCudaTensor_resizeAs(state, gradInput, data);
+  THCudaTensor_zero(state, gradInput);
+
+  long count = THCudaTensor_nElement(state, gradOutput);
+
+  ROIPoolBackwardAtomic<float><<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
+      count,
+      THCudaTensor_data(state, gradOutput),
+      (int*)THCudaTensor_data(state, indices),
+      num_rois, spatial_scale, nInputPlane, data->size[2], data->size[3], H, W,
+      THCudaTensor_data(state, gradInput),
+      THCudaTensor_data(state, rois)
+      );
+
+  // check for errors
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in inn_ROIPooling_updateGradInputAtomic: %s\n", cudaGetErrorString(err));
+    THError("aborting");
+  }
+}
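The backward kernel routes each output gradient to the input element recorded in indices during the forward pass (the float storage is reinterpreted as int via the (int*) cast, so indices should not be read as numbers from Lua). Several pooled cells, and in particular several overlapping ROIs, can select the same input element as their argmax, which is why the writes must be atomicAdd; an argmax of -1 marks an empty bin and contributes nothing. A host-side Lua sketch of the accumulation, with a purely illustrative nested-table layout:

    -- argmax[n][c][k]: flat h*w index picked for pooled cell k of ROI n,
    -- channel c, or -1 for an empty bin; gradInput tables are pre-zeroed
    local function accumulate(gradInput, gradOutput, argmax, roiBatch)
      for n, perRoi in ipairs(argmax) do
        local b = roiBatch[n]  -- image this ROI was cut from
        for c, cells in ipairs(perRoi) do
          for k, idx in ipairs(cells) do
            if idx ~= -1 then
              -- several (n, k) pairs may hit the same idx: atomicAdd on the GPU
              gradInput[b][c][idx] = gradInput[b][c][idx] + gradOutput[n][c][k]
            end
          end
        end
      end
      return gradInput
    end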
diff --git a/ROIPooling.lua b/ROIPooling.lua
new file mode 100644
index 0000000..af2b00c
--- /dev/null
+++ b/ROIPooling.lua
@@ -0,0 +1,53 @@
+local ROIPooling,parent = torch.class('inn.ROIPooling', 'nn.Module')
+local C = inn.C
+
+function ROIPooling:__init(W,H,spatial_scale)
+  parent.__init(self)
+  assert(W and H, 'W and H have to be provided')
+  self.W = W
+  self.H = H
+  self.spatial_scale = spatial_scale or 1
+  self.gradInput = {}
+  self.indices = torch.Tensor()
+  self.v2 = true
+end
+
+function ROIPooling:setSpatialScale(scale)
+  self.spatial_scale = scale
+  return self
+end
+
+function ROIPooling:updateOutput(input)
+  assert(#input == 2)
+  local data = input[1]
+  local rois = input[2]
+
+  if self.v2 then
+    C.inn_ROIPooling_updateOutputV2(cutorch.getState(),
+      self.output:cdata(), self.indices:cdata(), data:cdata(), rois:cdata(),
+      self.W, self.H, self.spatial_scale)
+  else
+    C.inn_ROIPooling_updateOutput(cutorch.getState(),
+      self.output:cdata(), self.indices:cdata(), data:cdata(), rois:cdata(),
+      self.W, self.H, self.spatial_scale)
+  end
+  return self.output
+end
+
+function ROIPooling:updateGradInput(input,gradOutput)
+  local data = input[1]
+  local rois = input[2]
+
+  self.gradInput_boxes = self.gradInput_boxes or data.new()
+  self.gradInput_rois = self.gradInput_rois or data.new()
+
+  C.inn_ROIPooling_updateGradInputAtomic(cutorch.getState(),
+    self.gradInput_boxes:cdata(), self.indices:cdata(), data:cdata(),
+    gradOutput:cdata(), rois:cdata(), self.W, self.H, self.spatial_scale)
+
+  self.gradInput_rois:resizeAs(rois):zero()
+
+  self.gradInput = {self.gradInput_boxes, self.gradInput_rois}
+
+  return self.gradInput
+end
diff --git a/SpatialStochasticPooling.cu b/SpatialStochasticPooling.cu
index 9cf6a92..fbc4ea2 100644
--- a/SpatialStochasticPooling.cu
+++ b/SpatialStochasticPooling.cu
@@ -1,19 +1,6 @@
 #include <THC.h>
-// CUDA: grid stride looping
-#define CUDA_KERNEL_LOOP(i, n) \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
-       i < (n); \
-       i += blockDim.x * gridDim.x)
-
-// Use 1024 threads per block, which requires cuda sm_2x or above
-const int CUDA_NUM_THREADS = 1024;
-
-// CUDA: number of blocks for threads.
-inline int GET_BLOCKS(const int N) {
-  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
-}
-
+#include "common.h"
 
 // kernels borrowed from Caffe
diff --git a/common.h b/common.h
new file mode 100644
index 0000000..59cf80c
--- /dev/null
+++ b/common.h
@@ -0,0 +1,18 @@
+#ifndef INN_COMMON_H
+#define INN_COMMON_H
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+       i < (n); \
+       i += blockDim.x * gridDim.x)
+
+// Use 1024 threads per block, which requires cuda sm_2x or above
+const int CUDA_NUM_THREADS = 1024;
+
+// CUDA: number of blocks for threads.
+inline int GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+#endif
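common.h now carries the launch helpers both .cu files share: CUDA_KERNEL_LOOP is the usual grid-stride loop, letting a fixed-size grid cover any element count, and GET_BLOCKS is a ceiling division of the count by the 1024-thread block size. The same arithmetic in plain Lua, for illustration only:

    local CUDA_NUM_THREADS = 1024
    local function getBlocks(n)
      -- integer ceil(n / CUDA_NUM_THREADS)
      return math.floor((n + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS)
    end
    assert(getBlocks(1) == 1 and getBlocks(1024) == 1 and getBlocks(1025) == 2)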
diff --git a/ffi.lua b/ffi.lua
index d959294..02a29dc 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -24,6 +24,16 @@ void LRNforward(struct THCState* state, THCudaTensor* input,
 void LRNbackward(struct THCState* state, THCudaTensor* input, THCudaTensor* output,
   THCudaTensor* gradOutput, THCudaTensor* gradInput, THCudaTensor* scale,
   int local_size, float alpha, float beta, float k);
+
+void inn_ROIPooling_updateOutput(THCState *state,
+    THCudaTensor *output, THCudaTensor *indices,
+    THCudaTensor *data, THCudaTensor* rois, int W, int H, double spatial_scale);
+void inn_ROIPooling_updateOutputV2(THCState *state,
+    THCudaTensor *output, THCudaTensor *indices,
+    THCudaTensor *data, THCudaTensor* rois, int W, int H, double spatial_scale);
+void inn_ROIPooling_updateGradInputAtomic(THCState *state,
+    THCudaTensor *gradInput, THCudaTensor *indices, THCudaTensor *data,
+    THCudaTensor *gradOutput, THCudaTensor* rois, int W, int H, double spatial_scale);
 ]]
 
 inn.C = ffi.load(package.searchpath('libinn', package.cpath))
diff --git a/init.lua b/init.lua
index 8c29ba8..4ea020b 100644
--- a/init.lua
+++ b/init.lua
@@ -8,3 +8,4 @@ include 'SpatialCrossResponseNormalization.lua'
 include 'MeanSubtraction.lua'
 include 'SpatialPyramidPooling.lua'
 include 'SpatialSameResponseNormalization.lua'
+include 'ROIPooling.lua'
diff --git a/test/test_jacobian.lua b/test/test_jacobian.lua
index 1482478..f0f57b8 100644
--- a/test/test_jacobian.lua
+++ b/test/test_jacobian.lua
@@ -86,6 +86,58 @@ function inntest.SpatialSameResponseNormalization()
   mytester:assertlt(err, precision, 'error on state (Batch) ')
 end
 
+function randROI(sz, n)
+  assert(sz:size() == 4, "need 4d size")
+  local roi = torch.Tensor(n, 5)
+  for i = 1, n do
+    local idx = torch.randperm(sz[1])[1]
+    local y = torch.randperm(sz[3])[{{1,2}}]:sort()
+    local x = torch.randperm(sz[4])[{{1,2}}]:sort()
+    roi[{i,{}}] = torch.Tensor({idx, x[1], y[1], x[2], y[2]})
+  end
+  return roi
+end
+
+function testJacobianWithRandomROI(cls, v2)
+  -- pooling grid size
+  local w = 4
+  local h = 4
+  -- input size
+  local W = w * 2
+  local H = h * 2
+
+  local batchSize = 3
+  local numRoi = batchSize
+  local numRepeat = 3
+
+  torch.manualSeed(0)
+  for i = 1, numRepeat do
+    local input = torch.rand(batchSize, 1, H, W)
+    local roi = randROI(input:size(), numRoi)
+    local module = cls.new(h, w, 1, roi)
+    module.v2 = v2
+    local err = jac.testJacobian(module, input, nil, nil, 1e-3)
+    mytester:assertlt(err, precision, 'error on ROIPooling '..(v2 and 'v2' or 'v1'))
+  end
+end
+
+function inntest.ROIPooling()
+  local FixedROIPooling, parent = torch.class('FixedROIPooling', 'inn.ROIPooling')
+  function FixedROIPooling:__init(W, H, s, roi)
+    self.roi = roi
+    parent.__init(self, W, H, s)
+    self:cuda()
+  end
+
+  function FixedROIPooling:updateOutput(input)
+    return parent.updateOutput(self, {input:cuda(), self.roi})
+  end
+  function FixedROIPooling:updateGradInput(input, gradOutput)
+    return parent.updateGradInput(self, {input:cuda(), self.roi}, gradOutput)[1]
+  end
+
+  testJacobianWithRandomROI(FixedROIPooling, true)
+end
 
 jac = nn.Jacobian
 mytester:add(inntest)
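For reference, the module this patch adds is driven like the following minimal sketch. The ROI convention {batch_index, x1, y1, x2, y2}, 1-indexed in input-image coordinates, comes from the kernels above; the tensor sizes and the 1/16 scale are made-up values for a typical convolutional feature map:

    require 'inn'
    local pool = inn.ROIPooling(6, 6):setSpatialScale(1/16):cuda()
    -- 2 images, 512 feature planes, 38x50 feature map
    local feat = torch.CudaTensor(2, 512, 38, 50):normal()
    -- one ROI per row: {batch_index, x1, y1, x2, y2}
    local rois = torch.Tensor{{1, 1, 1, 320, 320},
                              {2, 33, 49, 640, 480}}:cuda()
    local out = pool:forward{feat, rois}   -- 2 x 512 x 6 x 6 output
    local grad = pool:backward({feat, rois}, out:clone():fill(1))

Setting pool.v2 = false before the forward call switches back to the original Fast R-CNN binning; the Jacobian test above exercises the v2 path.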