-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GemmConvMobileFunction(optimized for mobile) #7034
Changes from 5 commits
dbf1d75
d775895
1954794
a850dec
f453b71
b7c4b58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,14 +126,163 @@ class GemmConvFunction : public ConvFunctionBase { | |
inputData += inputChannels * inputHeight * inputWidth; | ||
outputData += outputChannels * outputHeight * outputWidth; | ||
} | ||
} | ||
}; | ||
|
||
#ifdef PADDLE_MOBILE_INFERENCE | ||
if (Device == DEVICE_TYPE_CPU) { | ||
memory_.reset(); | ||
|
||
/* | ||
* \brief Forward calculation of convolution, optimized for mobile. | ||
*/ | ||
template <DeviceType Device> | ||
class GemmConvMobileFunction : public ConvFunctionBase { | ||
public: | ||
void init(const FuncConfig& config) override { | ||
ConvFunctionBase::init(config); | ||
} | ||
|
||
void check(const BufferArgs& inputs, const BufferArgs& outputs) override { | ||
const TensorShape& input = inputs[0].shape(); | ||
const TensorShape& filter = inputs[1].shape(); | ||
const TensorShape& output = outputs[0].shape(); | ||
checkShape(input, filter, output); | ||
} | ||
|
||
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { | ||
CHECK_EQ(numInputs_, inputs.size()); | ||
CHECK_EQ(numOutputs_, outputs.size()); | ||
check(inputs, outputs); | ||
// TODO(hedaoyuan): Need to define some index macros, | ||
// to avoid useing 0 and 1. | ||
const TensorShape& input = inputs[0].shape(); | ||
const TensorShape& filter = inputs[1].shape(); | ||
const TensorShape& output = outputs[0].shape(); | ||
|
||
real beta; | ||
if (outputs[0].getArgType() == ADD_TO) { | ||
beta = 1.0; | ||
} else { | ||
beta = 0.0; | ||
} | ||
|
||
size_t batchSize = input[0]; | ||
size_t inputChannels = input[1]; | ||
size_t inputHeight = input[2]; | ||
size_t inputWidth = input[3]; | ||
size_t filterHeight = getFilterHeight(filter); | ||
size_t filterWidth = getFilterWidth(filter); | ||
size_t outputChannels = output[1]; | ||
size_t outputHeight = output[2]; | ||
size_t outputWidth = output[3]; | ||
|
||
real* inputData = inputs[0].data<real>(); | ||
real* filterData = inputs[1].data<real>(); | ||
real* outputData = outputs[0].data<real>(); | ||
bool needIm2col = isNeedIm2col(filter); | ||
|
||
TensorShape imShape = | ||
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); | ||
|
||
TensorShape colShape; | ||
real* colData = NULL; | ||
|
||
size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; | ||
size_t colWidth = outputHeight * outputWidth; | ||
// Max col matrix height 256, Max col matrix width 1024 | ||
size_t stepColHeight = std::min(colHeight, (size_t)256); | ||
size_t stepColWidth = std::min(colWidth, (size_t)2048); | ||
|
||
if (needIm2col) { | ||
colShape = TensorShape({inputChannels / groups_, | ||
filterHeight, | ||
filterWidth, | ||
outputHeight, | ||
outputWidth}); | ||
|
||
resizeBuffer<Device>(stepColHeight * stepColWidth * sizeof(real)); | ||
colData = reinterpret_cast<real*>(memory_->getBuf()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. Add release the memory. |
||
} | ||
|
||
Im2ColMobileFunctor<real> im2col; | ||
size_t inputOffset = imShape.getElements(); | ||
size_t outputOffset = | ||
(outputChannels / groups_) * outputHeight * outputWidth; | ||
size_t filterOffset = filter.getElements() / groups_; | ||
|
||
int nStride = colWidth; | ||
int kStride = colHeight; | ||
for (size_t i = 0; i < batchSize; i++) { | ||
for (size_t g = 0; g < groups_; g++) { | ||
if (needIm2col) { | ||
real beta_ = beta; | ||
for (size_t colHeightStart = 0; colHeightStart < colHeight; | ||
colHeightStart += stepColHeight) { | ||
for (size_t colWidthStart = 0; colWidthStart < colWidth; | ||
colWidthStart += stepColWidth) { | ||
int N = std::min(colWidth - colWidthStart, stepColWidth); | ||
int K = std::min(colHeight - colHeightStart, stepColHeight); | ||
// im2col | ||
im2col(inputData + g * inputOffset, | ||
imShape, | ||
colData, | ||
colShape, | ||
strideH(), | ||
strideW(), | ||
paddingH(), | ||
paddingW(), | ||
dilationH(), | ||
dilationW(), | ||
colHeightStart, | ||
K, | ||
colWidthStart, | ||
N); | ||
|
||
// gemm | ||
int M = outputChannels / groups_; | ||
BlasGemm<Device, real>::compute( | ||
false, | ||
false, | ||
M, | ||
N, | ||
K, | ||
1.0f, | ||
filterData + g * filterOffset + colHeightStart, | ||
kStride, | ||
colData, | ||
N, | ||
beta_, | ||
outputData + g * outputOffset + colWidthStart, | ||
nStride); | ||
} | ||
beta_ = 1.0; | ||
} | ||
} else { | ||
int M = outputChannels / groups_; | ||
int N = outputHeight * outputWidth; | ||
int K = inputChannels / groups_ * filterHeight * filterWidth; | ||
BlasGemm<Device, real>::compute(false, | ||
false, | ||
M, | ||
N, | ||
K, | ||
1.0f, | ||
filterData + g * filterOffset, | ||
K, | ||
inputData + g * inputOffset, | ||
N, | ||
beta, | ||
outputData + g * outputOffset, | ||
N); | ||
} | ||
} | ||
inputData += inputChannels * inputHeight * inputWidth; | ||
outputData += outputChannels * outputHeight * outputWidth; | ||
} | ||
#endif | ||
} | ||
}; | ||
|
||
#endif | ||
|
||
/* | ||
* \brief Backward input calculation of convolution. | ||
*/ | ||
|
@@ -348,7 +497,11 @@ class GemmConvGradFilterFunction : public ConvFunctionBase { | |
} | ||
}; | ||
|
||
#ifdef PADDLE_MOBILE_INFERENCE | ||
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction); | ||
#else | ||
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); | ||
#endif | ||
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); | ||
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); | ||
#ifdef PADDLE_WITH_CUDA | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,4 +98,54 @@ class Col2ImFunctor { | |
int dilationWidth = 1); | ||
}; | ||
|
||
template <class T> | ||
class Im2ColMobileFunctor { | ||
public: | ||
void operator()(const T* imData, | ||
const TensorShape& imShape, | ||
T* colData, | ||
const TensorShape& colShape, | ||
int strideHeight, | ||
int strideWidth, | ||
int paddingHeight, | ||
int paddingWidth, | ||
int dilationHeight, | ||
int dilationWidth, | ||
int colHeightStart, | ||
int colHeightSize, | ||
int colWidthStart, | ||
int colWidthSize) { | ||
int inputHeight = imShape[1]; | ||
int inputWidth = imShape[2]; | ||
int filterHeight = colShape[1]; | ||
int filterWidth = colShape[2]; | ||
int outputWidth = colShape[4]; | ||
|
||
for (int colh = 0; colh < colHeightSize; colh++) { | ||
int wOffset = (colHeightStart + colh) % filterWidth; | ||
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; | ||
int c_im = (colHeightStart + colh) / filterWidth / filterHeight; | ||
|
||
for (int colw = 0; colw < colWidthSize; colw++) { | ||
int h = (colWidthStart + colw) / outputWidth; | ||
int w = (colWidthStart + colw) % outputWidth; | ||
|
||
int imRowIdx = h * strideHeight + hOffset * dilationHeight; | ||
int imColIdx = w * strideWidth + wOffset * dilationWidth; | ||
if ((imRowIdx - paddingHeight) < 0 || | ||
(imRowIdx - paddingHeight) >= inputHeight || | ||
(imColIdx - paddingWidth) < 0 || | ||
(imColIdx - paddingWidth) >= inputWidth) { | ||
colData[colh * colWidthSize + colw] = T(0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
} else { | ||
imRowIdx += c_im * inputHeight - paddingHeight; | ||
imColIdx -= paddingWidth; | ||
colData[colh * colWidthSize + colw] = | ||
imData[imRowIdx * inputWidth + imColIdx]; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,4 +138,86 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); } | |
|
||
#endif | ||
|
||
template <class T> | ||
void TestIm2ColMobileFunctor() { | ||
for (size_t channels : {1, 5, 32}) { | ||
for (size_t inputHeight : {5, 33, 100}) { | ||
for (size_t inputWidth : {5, 32, 96}) { | ||
for (size_t filterHeight : {1, 5}) { | ||
for (size_t filterWidth : {3, 7}) { | ||
for (size_t stride : {1, 2}) { | ||
for (size_t padding : {0, 1}) { | ||
for (size_t dilation : {1, 3}) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe the test case can be reduced to speed up the unit testing time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
size_t filterSizeH = (filterHeight - 1) * dilation + 1; | ||
size_t filterSizeW = (filterWidth - 1) * dilation + 1; | ||
if (inputHeight + 2 * padding < filterSizeH || | ||
inputWidth + 2 * padding < filterSizeW) | ||
break; | ||
if (padding >= filterSizeH || padding >= filterSizeW) break; | ||
size_t outputHeight = | ||
(inputHeight - filterSizeH + 2 * padding) / stride + 1; | ||
size_t outputWidth = | ||
(inputWidth - filterSizeW + 2 * padding) / stride + 1; | ||
|
||
TensorShape imShape = | ||
TensorShape({channels, inputHeight, inputWidth}); | ||
TensorShape colShape1 = TensorShape({channels, | ||
filterHeight, | ||
filterWidth, | ||
outputHeight, | ||
outputWidth}); | ||
|
||
size_t height = channels * filterHeight * filterWidth; | ||
size_t width = outputHeight * outputWidth; | ||
VectorPtr input1 = | ||
Vector::create(imShape.getElements(), false); | ||
VectorPtr input2 = | ||
Vector::create(imShape.getElements(), false); | ||
MatrixPtr output1 = | ||
Matrix::create(height, width, false, false); | ||
MatrixPtr output2 = | ||
Matrix::create(height, width, false, false); | ||
input1->uniform(0.001, 1); | ||
input2->copyFrom(*input1); | ||
|
||
Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> im2Col1; | ||
Im2ColMobileFunctor<T> im2Col2; | ||
im2Col1(input1->getData(), | ||
imShape, | ||
output1->getData(), | ||
colShape1, | ||
stride, | ||
stride, | ||
padding, | ||
padding, | ||
dilation, | ||
dilation); | ||
im2Col2(input2->getData(), | ||
imShape, | ||
output2->getData(), | ||
colShape1, | ||
stride, | ||
stride, | ||
padding, | ||
padding, | ||
dilation, | ||
dilation, | ||
0, | ||
height, | ||
0, | ||
width); | ||
|
||
autotest::TensorCheckEqual(*output1, *output2); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor<float>(); } | ||
|
||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.