Skip to content

Commit

Permalink
tyf/change pooling and relu (DeepLink-org#765)
Browse files Browse the repository at this point in the history
* change memory format of pooling and relu
  • Loading branch information
Bonbon-Tang authored Dec 28, 2023
1 parent 724c640 commit 385ce67
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 23 deletions.
18 changes: 18 additions & 0 deletions impl/camb/convert_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,21 @@

- diopiBatchNormStats:
layout: NLC, NHWC, NDHWC

#* Ops below are not neccesary to convert format *#
# next version will be changed,it is better to do nothing in adaptor
- diopiMaxPool2dWithIndices:
layout: NHWC

- diopiMaxPool2d:
layout: NHWC

- diopiMaxPool2dBackward:
layout: NHWC

- diopiReluInp:
layout: NLC, NHWC, NDHWC

- diopiThresholdBackward:
layout: NLC, NHWC, NDHWC
#* Ops above are not neccesary to convert format *#
84 changes: 61 additions & 23 deletions impl/camb/functions/max_pool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ diopiError_t diopiMaxPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, d

DiopiTensor inputTr(input);
DiopiTensor outTr(out);

DIOPI_CHECK(inputTr.dim() == 3 || inputTr.dim() == 4, "non-empty 3D or 4D (batch mode) tensor expected for input");
DIOPI_CHECK(inputTr.dim() == 3 || inputTr.dim() == 4, "only support 3D or 4D tensor for input");
if (inputTr.dim() == 3) {
DIOPI_CHECK(inputTr.isContiguous(diopiMemoryFormat_t::Contiguous), "only support contiguous for 3D input");
} else {
DIOPI_CHECK(inputTr.isContiguous(diopiMemoryFormat_t::ChannelsLast), "only support ChannelsLast for 4D input");
}

std::vector<DiopiTensor*> pTensors{&inputTr};
DIOPI_CALL(autoCastTensorType(ctx, pTensors, {diopi_dtype_float16, diopi_dtype_float32}));

int inDim = inputTr.dim();
if (inputTr.dim() == 3) {
inputTr.unsqueeze(0);
outTr.unsqueeze(0);
Expand All @@ -34,11 +39,15 @@ diopiError_t diopiMaxPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
if (inputTr.dtype() != outTr.dtype()) {
outTmpTr = requiresTensor(ctx, outTr.shape(), inputTr.dtype());
}

std::vector<int64_t> inputDim = inputTr.shape();
std::vector<int64_t> outDim = outTmpTr.shape();
CnnlTensorDesc inputDesc(inputTr, CNNL_LAYOUT_NCHW);
CnnlTensorDesc outDesc(outTmpTr, CNNL_LAYOUT_NCHW);

cnnlTensorLayout_t layout = CNNL_LAYOUT_NHWC;
if (inDim == 3) {
layout = CNNL_LAYOUT_NCHW;
}
CnnlTensorDesc inputDesc(inputTr, layout);
CnnlTensorDesc outDesc(outTmpTr, layout);

const int64_t kernelH = kernelSize.data[0];
const int64_t kernelW = kernelSize.len == 1 ? kernelH : kernelSize.data[1];
Expand All @@ -51,6 +60,7 @@ diopiError_t diopiMaxPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
strideH = stride.data[0];
strideW = stride.len == 1 ? strideH : stride.data[1];
}

const int64_t padH = padding.data[0];
const int64_t padW = padding.len == 1 ? padH : padding.data[1];
const int64_t dilation0 = dilation.data[0];
Expand Down Expand Up @@ -85,7 +95,6 @@ diopiError_t diopiMaxPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
if (outTmpTr.dtype() != outTr.dtype()) {
DIOPI_CALL(dataTypeCast(ctx, outTr, outTmpTr));
}

return diopiSuccess;
}

Expand All @@ -100,12 +109,17 @@ diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHand
DIOPI_CHECK(inputTr.dim() == 3 || inputTr.dim() == 4, "non-empty 3D or 4D (batch mode) tensor expected for input");

std::vector<DiopiTensor*> pTensors{&inputTr};
DIOPI_CALL(autoCastTensorType(ctx, pTensors, {diopi_dtype_float16, diopi_dtype_float32}));

if (inputTr.dtype() != diopi_dtype_float16 && inputTr.dtype() != diopi_dtype_float32) {
DIOPI_CALL(autoCastTensorType(ctx, pTensors, {diopi_dtype_float16, diopi_dtype_float32}));
}
int inDim = inputTr.dim();
if (inputTr.dim() == 3) {
DIOPI_CHECK(inputTr.isContiguous(diopiMemoryFormat_t::Contiguous), "only support contiguous for 3D input");
inputTr.unsqueeze(0);
indicesTr.unsqueeze(0);
outTr.unsqueeze(0);
} else { // dim() == 4
DIOPI_CHECK(inputTr.isContiguous(diopiMemoryFormat_t::ChannelsLast), "only support contiguous for 3D input");
}

DiopiTensor outTmpTr = outTr;
Expand All @@ -120,9 +134,14 @@ diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHand

std::vector<int64_t> inputDim = inputTr.shape();
std::vector<int64_t> outDim = outTmpTr.shape();
CnnlTensorDesc inputDesc(inputTr, CNNL_LAYOUT_NCHW);
CnnlTensorDesc indicesDesc(indicesTmpTr, CNNL_LAYOUT_NCHW);
CnnlTensorDesc outDesc(outTmpTr, CNNL_LAYOUT_NCHW);

cnnlTensorLayout_t layout = CNNL_LAYOUT_NHWC;
if (inDim == 3) {
layout = CNNL_LAYOUT_NCHW;
}
CnnlTensorDesc inputDesc(inputTr, layout);
CnnlTensorDesc indicesDesc(indicesTmpTr, layout);
CnnlTensorDesc outDesc(outTmpTr, layout);

const int64_t kernelH = kernelSize.data[0];
const int64_t kernelW = kernelSize.len == 1 ? kernelH : kernelSize.data[1];
Expand Down Expand Up @@ -217,30 +236,51 @@ diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_
DiopiTensor gradInputTr(gradInput);
DiopiTensor gradOutputTr(gradOutput);
DiopiTensor indicesTr(indices);

DIOPI_CHECK(inputTr.dim() == 3 || inputTr.dim() == 4, "non-empty 3D or 4D (batch mode) tensor expected for input");

if (inputTr.dim() == 3) {
DIOPI_CHECK(inputTr.dim() == indicesTr.dim() && inputTr.dim() == gradOutputTr.dim() && inputTr.dim() == gradInputTr.dim(),
"the shapes of input(%ld), indices(%ld), gradOutput(%ld) and gradInput(%ld) should be same",
inputTr.dim(),
indicesTr.dim(),
gradOutputTr.dim(),
gradInputTr.dim());
DIOPI_CHECK(inputTr.dim() == 3 || inputTr.dim() == 4, "3D or 4D (batch mode) tensor expected for input");
bool is3dim = inputTr.dim() == 3;
if (is3dim) {
DIOPI_CHECK(inputTr.isContiguous(diopiMemoryFormat_t::Contiguous), "only support contiguous for 3D input");
DIOPI_CHECK(indicesTr.isContiguous(diopiMemoryFormat_t::Contiguous), "only support contiguous for 3D indices");
DIOPI_CHECK(gradInputTr.isContiguous(diopiMemoryFormat_t::Contiguous), "only support contiguous for 3D gradInputTr");
DIOPI_CHECK(gradOutputTr.isContiguous(diopiMemoryFormat_t::Contiguous), "only support contiguous for 3D gradOutput");
inputTr.unsqueeze(0);
indicesTr.unsqueeze(0);
gradInputTr.unsqueeze(0);
gradOutputTr.unsqueeze(0);
} else { // dim() == 4
DIOPI_CHECK(inputTr.isContiguous(diopiMemoryFormat_t::ChannelsLast), "only support ChannelsLast for 4D input");
DIOPI_CHECK(indicesTr.isContiguous(diopiMemoryFormat_t::ChannelsLast), "only support ChannelsLast for 4D indices");
DIOPI_CHECK(gradInputTr.isContiguous(diopiMemoryFormat_t::ChannelsLast), "only support ChannelsLast for 4D gradInputTr");
DIOPI_CHECK(gradOutputTr.isContiguous(diopiMemoryFormat_t::ChannelsLast), "only support ChannelsLast for 4D gradOutput");
}

std::vector<DiopiTensor*> pTensors{&inputTr, &gradOutputTr};
DIOPI_CALL(autoCastTensorType(ctx, pTensors, {diopi_dtype_float16, diopi_dtype_float32}));
if (inputTr.dtype() != gradOutputTr.dtype() || (inputTr.dtype() != diopi_dtype_float16 && inputTr.dtype() != diopi_dtype_float32)) {
DIOPI_CALL(autoCastTensorType(ctx, pTensors, {diopi_dtype_float16, diopi_dtype_float32}));
}

if (inputTr.dtype() == diopi_dtype_float16) {
DIOPI_CALL(dataTypeCast(ctx, indicesTr, diopi_dtype_int16));
} else {
DIOPI_CALL(dataTypeCast(ctx, indicesTr, diopi_dtype_int32));
}

diopiMemoryFormat_t memoryFormat = diopiMemoryFormat_t::ChannelsLast;
DIOPI_CALL(contiguous(ctx, inputTr, memoryFormat));
DIOPI_CALL(contiguous(ctx, gradOutputTr, memoryFormat));
DIOPI_CALL(contiguous(ctx, indicesTr, memoryFormat));
DiopiTensor gradInputTmpTr = requiresTensor(ctx, gradInputTr.shape(), inputTr.dtype(), memoryFormat);
// for 3 dim input, it is contiguous, and needs to convert to channelslast for camb kernel.
if (is3dim) {
DIOPI_CALL(contiguous(ctx, inputTr, memoryFormat));
DIOPI_CALL(contiguous(ctx, gradOutputTr, memoryFormat));
DIOPI_CALL(contiguous(ctx, indicesTr, memoryFormat));
}
DiopiTensor gradInputTmpTr = gradInputTr;
if (is3dim) {
gradInputTmpTr = requiresTensor(ctx, gradInputTr.shape(), inputTr.dtype(), memoryFormat);
}

std::vector<int64_t> inputDim = inputTr.shape();
std::vector<int64_t> gradOutputDim = gradOutputTr.shape();
Expand Down Expand Up @@ -299,8 +339,6 @@ diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTensorHandle_
gradInputDesc.get(),
gradInputTmpTr.data()));

// Channels last -> contiguous
DIOPI_CALL(contiguous(ctx, gradInputTmpTr, diopiMemoryFormat_t::Contiguous));
DIOPI_CALL(diopiCopyInp(ctx, gradInputTmpTr.tensorHandle(), gradInputTr.tensorHandle()));

return diopiSuccess;
Expand Down

0 comments on commit 385ce67

Please sign in to comment.