Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support rectangle padding, stride, window and input for PoolProjection #115

Merged
merged 3 commits into from
Oct 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 58 additions & 30 deletions paddle/cuda/include/hl_cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,23 @@ extern void hl_expand_feature2col(
* @param[in] width image width.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeX size of pooling window.
* @param[in] stride pooling stride.
* @param[in] start pooling start.
* @param[in] sizeX width of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
*
*/
extern void hl_maxpool_forward(
int frameCnt, const real* inputData, int channels,
int height, int width, int pooledH, int pooledW,
int sizeX, int stride, int start, real* tgtData);
const int frameCnt, const real* inputData,
const int channels,
const int height, const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
const int paddingH, const int paddingW, real* tgtData);

/**
* @brief Maximum pool backward.
Expand All @@ -107,21 +114,28 @@ extern void hl_maxpool_forward(
* @param[in] width image width.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeX size of pooling window.
* @param[in] stride pooling stride.
* @param[in] start pooling start.
* @param[out] targetGrad output grad.
* @param[in] sizeX width of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] scaleA scale.
* @param[in] scaleB scale.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] targetGrad output grad.
*
*/
extern void hl_maxpool_backward(
int frameCnt, const real* inputData,
const int frameCnt, const real* inputData,
const real* outData, const real* outGrad,
int channels, int height, int width,
int pooledH, int pooledW, int sizeX,
int stride, int start, real* targetGrad,
real scaleA, real scaleB);
const int channels, const int height,
const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
const int paddingH, const int paddingW,
real scaleA, real scaleB,
real* targetGrad);

/**
* @brief Averge pool forward.
Expand All @@ -133,16 +147,23 @@ extern void hl_maxpool_backward(
* @param[in] width image width.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeX size of pooling window.
* @param[in] stride pooling stride.
* @param[in] start pooling start.
* @param[in] sizeX width of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
*
*/
extern void hl_avgpool_forward(
int frameCnt, const real* inputData, int channels,
int height, int width, int pooledH, int pooledW,
int sizeX, int stride, int start, real* tgtData);
const int frameCnt, const real* inputData,
const int channels,
const int height, const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
const int paddingH, const int paddingW, real* tgtData);

/**
* @brief Maximum pool backward.
Expand All @@ -154,20 +175,27 @@ extern void hl_avgpool_forward(
* @param[in] width image width.
* @param[in] pooledH output image height.
* @param[in] pooledW output image width.
* @param[in] sizeX size of pooling window.
* @param[in] stride pooling stride.
* @param[in] start pooling start.
* @param[out] backGrad output grad.
* @param[in] sizeX width of pooling window.
* @param[in] sizeY height of pooling window.
* @param[in] strideH pooling stride height.
* @param[in] strideW pooling stride width.
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[in] scaleA scale.
* @param[in] scaleB scale.
* @param[out] backGrad output grad.
*
*/
extern void hl_avgpool_backward(
int frameCnt, const real* outGrad,
int channels, int height, int width,
int pooledH, int pooledW, int sizeX,
int stride, int start, real* backGrad,
real scaleA, real scaleB);
const int frameCnt, const real* outGrad,
const int channels, const int height,
const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
int paddingH, int paddingW,
real scaleA, real scaleB,
real* backGrad);

/**
* @brief Cross-map-respose normalize forward.
Expand Down
48 changes: 32 additions & 16 deletions paddle/cuda/include/stub/hl_cnn_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,45 @@ inline void hl_expand_feature2col(
real* dataCol) {}

inline void hl_maxpool_forward(
int frameCnt, const real* inputData, int channels,
int height, int width, int pooledH, int pooledW,
int sizeX, int stride, int start, real* tgtData) {}
const int frameCnt, const real* inputData,
const int channels,
const int height, const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
const int paddingH, const int paddingW, real* tgtData) {}

inline void hl_maxpool_backward(
int frameCnt, const real* inputData,
const int frameCnt, const real* inputData,
const real* outData, const real* outGrad,
int channels, int height, int width,
int pooledH, int pooledW, int sizeX,
int stride, int start, real* targetGrad,
real scaleA, real scaleB) {}
const int channels, const int height,
const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
const int paddingH, const int paddingW,
real scaleA, real scaleB,
real* targetGrad) {}

inline void hl_avgpool_forward(
int frameCnt, const real* inputData, int channels,
int height, int width, int pooledH, int pooledW,
int sizeX, int stride, int start, real* tgtData) {}
const int frameCnt, const real* inputData,
const int channels,
const int height, const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
const int paddingH, const int paddingW, real* tgtData) {}

inline void hl_avgpool_backward(
int frameCnt, const real* outGrad,
int channels, int height, int width,
int pooledH, int pooledW, int sizeX,
int stride, int start, real* backGrad,
real scaleA, real scaleB) {}
const int frameCnt, const real* outGrad,
const int channels, const int height,
const int width,
const int pooledH, const int pooledW,
const int sizeX, const int sizeY,
const int strideH, const int strideW,
int paddingH, int paddingW,
real scaleA, real scaleB,
real* backGrad) {}

inline void hl_CMRNorm_forward(
size_t frameCnt, const real* in, real* scale, real* out,
Expand Down
Loading