Skip to content

Commit

Permalink
Backward data (#246)
Browse files Browse the repository at this point in the history
* bwd data: generate multiple kernels according to parameters

* bwd data: run according to gemm id

* host calculate gemm size using gemm id

* filter available gemm ids

* fix unit test for adding gemm id in gridwise

* bwd data vector load

* fix unit test

* add unit test and fix bug

* add comment and change ADT

* change comments

* release 2 e2e config of backward data

* add example for comment

* recover getKernelCount

* calculate gemmid from kernel id

* add kernel_id parameter in mlir-miopen-lib-test

* unint test add e2e test in mlir-miopen-lib-test

* Revert "unint test add e2e test in mlir-miopen-lib-test"

This reverts commit 918ba9b.

* Revert "add kernel_id parameter in mlir-miopen-lib-test"

This reverts commit 0459c15.

* add e2e of bwd in mlir-miopen-lib-test

* fix bwd unit test of mlir-miopen-lib-test

Co-authored-by: letaoqin <[email protected]>
Co-authored-by: Wen-Heng (Jack) Chung <[email protected]>
  • Loading branch information
3 people authored Jun 7, 2021
1 parent b755661 commit 22c4eb7
Show file tree
Hide file tree
Showing 22 changed files with 239 additions and 114 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class Conv2dGenerator {
SmallVector<int64_t, 5> filterDimension;
SmallVector<int64_t, 5> inputDimension;
SmallVector<int64_t, 5> outputDimension;

int filterHeight;
int filterWidth;
};

Conv2dGenerator(const std::string &arch = "", int num_cu = 0,
Expand Down Expand Up @@ -98,6 +101,7 @@ class Conv2dGenerator {
});
return permutation;
}
int getBwdDataKernelCount() const;

// Generator config
Config config;
Expand Down
30 changes: 26 additions & 4 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2407,7 +2407,7 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {

LogicalResult backwardData(T op, PatternRewriter &b) const {
auto loc = op.getLoc();

auto gemmIdAttr = op->template getAttrOfType<IntegerAttr>("gemm_id");
auto archAttr = op->template getAttrOfType<StringAttr>("arch");
auto numCuAttr = op->template getAttrOfType<IntegerAttr>("num_cu");

Expand Down Expand Up @@ -2500,8 +2500,8 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {
auto gcdStrideDilationH = math::gcd(strideH, dilationH);
auto gcdStrideDilationW = math::gcd(strideW, dilationW);

auto yTilda = dilationH / gcdStrideDilationH;
auto xTilda = dilationW / gcdStrideDilationW;
auto yTilda = strideH / gcdStrideDilationH;
auto xTilda = strideW / gcdStrideDilationW;

auto yDot = math::integer_divide_ceil(y, yTilda);
auto xDot = math::integer_divide_ceil(x, xTilda);
Expand All @@ -2522,7 +2522,28 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {
auto hTildaSlice = iHTildaRight - iHTildaLeft;
auto wTildaSlice = iWTildaRight - iWTildaLeft;

auto gemmId = 0;
// Translates a kernel id into a gemm id.
//
// Gemm ids range over [0, yTilda * xTilda). A gemm id is only worth running
// when its gemmK extent (yDotSlice * xDotSlice) is positive; kernel ids
// enumerate just those gemm ids, in increasing order. The lambda therefore
// returns the kernelId-th gemm id (0-based) that has a positive gemmK extent.
//
// An out-of-range kernelId trips the assert in debug builds and falls back to
// gemm id 0 in release builds instead of reading past the end of a container.
auto getGemmId = [&](int kernelId) {
  // kernelId 0 must be gemmId 0
  if (kernelId <= 0)
    return 0;

  // Number of runnable gemm ids seen so far while scanning.
  int runnableSeen = 0;
  for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) {
    // gemm_k size is different for each GEMM
    const auto iYTilda = gemmId / xTilda;
    const auto iXTilda = gemmId % xTilda;

    auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
    auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
    // gemmK must be > 0, otherwise this GEMM does not need to run
    if (yDotSlice * xDotSlice > 0) {
      if (runnableSeen == kernelId)
        return gemmId;
      ++runnableSeen;
    }
  }
  assert(false && "kernel id exceeds the number of runnable backward-data GEMMs");
  return 0;
};
auto gemmId = getGemmId(gemmIdAttr.getInt());
auto iYTilda = gemmId / xTilda;
auto iXTilda = gemmId % xTilda;
auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
Expand Down Expand Up @@ -3766,6 +3787,7 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {

// Set attributes for gridwise_gemm op.
llvm::SmallVector<NamedAttribute, 8> gridwiseGemmAttrs{
b.getNamedAttr("gemm_id", gemmIdAttr),
b.getNamedAttr("arch", archAttr),
b.getNamedAttr("num_cu", numCuAttr),
b.getNamedAttr("filter_layout", filterLayoutAttr),
Expand Down
15 changes: 11 additions & 4 deletions mlir/include/mlir/Dialect/MIOpen/Tuning/ConvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ struct ConvolutionContext : SQLiteSerializable<ConvolutionContext> {
llvm::SmallVector<int64_t, 0> strideVal;
llvm::SmallVector<int64_t, 0> dilationVal;
llvm::SmallVector<int64_t, 0> paddingVal;
int gemmId;

ConvolutionContext(const llvm::SmallString<8> &architecture, int numCu,
miopen::ConvOpType op,
llvm::StringMap<std::pair<size_t, int64_t>> dim,
llvm::SmallVector<int64_t, 0> stride,
llvm::SmallVector<int64_t, 0> dilation,
llvm::SmallVector<int64_t, 0> padding)
llvm::SmallVector<int64_t, 0> padding, int gemmid)
: arch(architecture), num_cu(numCu), opType(op), dimIndexVal(dim),
strideVal(stride), dilationVal(dilation), paddingVal(padding) {}
strideVal(stride), dilationVal(dilation), paddingVal(padding),
gemmId(gemmid) {}

llvm::StringMap<std::pair<size_t, int64_t>> getDimIndexVal() const {
return dimIndexVal;
Expand Down Expand Up @@ -146,6 +148,11 @@ template <typename T> static ConvolutionContext populateConvContext(T &op) {

auto archVal = op->template getAttrOfType<StringAttr>("arch").getValue();
int numCuVal = op->template getAttrOfType<IntegerAttr>("num_cu").getInt();
auto gemmIdAttr = op->template getAttrOfType<IntegerAttr>("gemm_id");
int gemmId = 0;
if (gemmIdAttr) {
gemmId = gemmIdAttr.getInt();
}

llvm::StringMap<std::pair<size_t, int64_t>> dimIndexVal;

Expand Down Expand Up @@ -176,8 +183,8 @@ template <typename T> static ConvolutionContext populateConvContext(T &op) {
llvm::SmallVector<int64_t, 0> paddingVal;
populateSeqVal(paddingAttr, paddingVal);

return {archVal, numCuVal, opType, dimIndexVal,
strideVal, dilationVal, paddingVal};
return {archVal, numCuVal, opType, dimIndexVal,
strideVal, dilationVal, paddingVal, gemmId};
}

} // namespace mlir
Expand Down
61 changes: 46 additions & 15 deletions mlir/include/mlir/Dialect/MIOpen/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,8 @@ class PopulateParamsBase {
input1GemmKVectorizable = true;
}
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
// When K is the fastest changing dimension(3),
// gemmK dimension is vectorizable, gemmM is not, and vice versa.
// Vectorization width depending on length of K.
if (dimIndexVal["k"].first == 4) {
input1GemmKVectorizable = true;
} else {
input1GemmKVectorizable = false;
}
// always load gemmM first
input1GemmKVectorizable = false;
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
// When K is the fastest changing dimension,
// gemmM dimension is vectorizable, gemmK is not, and vice versa.
Expand Down Expand Up @@ -200,6 +194,23 @@ class PopulateParamsBase {
}
}

// Computes the vectorization length for the filter tensor of a backward-data
// convolution and stores it in vecLen.
//
// - "c" at position 4: the length of "c".
// - "c" at position 2: the length of "c", but only when both "y" and "x" have
//   length 1; otherwise 1.
// - any other position of "c": 1.
static void obtainBwdDataFilterVecLen(ConvolutionContext &ctx,
                                      int64_t &vecLen) {
  // Work on a copy, as the original code does; presumably this keeps the
  // non-const map lookups below from touching ctx.
  auto dims = ctx.dimIndexVal;

  int64_t len = 1;
  if (dims["c"].first == 4) {
    len = dims["c"].second;
  } else if (dims["c"].first == 2 && dims["y"].second == 1 &&
             dims["x"].second == 1) {
    // C sits at position 2: vectorization length depends on the last two
    // dimensions both being of length 1.
    len = dims["c"].second;
  }
  vecLen = len;
}
static void obtainInputVecLen(ConvolutionContext &ctx, int64_t &vecLen) {
auto dimIndexVal = ctx.dimIndexVal;
if (dimIndexVal["ni"].first == 4) {
Expand All @@ -216,6 +227,26 @@ class PopulateParamsBase {
vecLen = 1;
}
}
// Computes the vectorization length for the output tensor of a backward-data
// convolution and stores it in vecLen.
//
// - "ko" at position 4: the length of "ko".
// - otherwise "no" at position 4: the length of "no".
// - otherwise, when "no" is at position 0, "ho" at 3 and "wo" at 4, and both
//   "y" and "x" have length 1: the product of the "ho" and "wo" lengths.
// - everything else: 1.
static void obtainBwdDataOutputVecLen(ConvolutionContext &ctx,
                                      int64_t &vecLen) {
  auto dims = ctx.dimIndexVal;

  int64_t len = 1;
  if (dims["ko"].first == 4) {
    len = dims["ko"].second;
  } else if (dims["no"].first == 4) {
    len = dims["no"].second;
  } else if (dims["no"].first == 0 && dims["ho"].first == 3 &&
             dims["wo"].first == 4 && dims["y"].second == 1 &&
             dims["x"].second == 1) {
    len = dims["ho"].second * dims["wo"].second;
  }
  vecLen = len;
}

static void obtainOutputVecLen(ConvolutionContext &ctx, int64_t &vecLen) {
auto dimIndexVal = ctx.dimIndexVal;
if (dimIndexVal["ko"].first == 4) {
Expand Down Expand Up @@ -250,7 +281,7 @@ class PopulateParamsBase {
if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
obtainFilterVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainFilterVecLen(ctx, vecLen);
obtainBwdDataFilterVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainOutputVecLen(ctx, vecLen);
}
Expand All @@ -261,7 +292,7 @@ class PopulateParamsBase {
if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
obtainInputVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainOutputVecLen(ctx, vecLen);
obtainBwdDataOutputVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainInputVecLen(ctx, vecLen);
}
Expand Down Expand Up @@ -385,8 +416,8 @@ class PopulateParamsBase {
auto gcdStrideDilationH = math::gcd(strideH, dilationH);
auto gcdStrideDilationW = math::gcd(strideW, dilationW);

auto yTilda = dilationH / gcdStrideDilationH;
auto xTilda = dilationW / gcdStrideDilationW;
auto yTilda = strideH / gcdStrideDilationH;
auto xTilda = strideW / gcdStrideDilationW;

auto hTilda =
ho + math::integer_divide_ceil(dilationH * (y - 1), strideH);
Expand All @@ -406,9 +437,9 @@ class PopulateParamsBase {
auto hTildaSlice = iHTildaRight - iHTildaLeft;
auto wTildaSlice = iWTildaRight - iWTildaLeft;

auto gemm_id = 0;
auto iYTilda = gemm_id / xTilda;
auto iXTilda = gemm_id % xTilda;
auto gemmId = ctx.gemmId;
auto iYTilda = gemmId / xTilda;
auto iXTilda = gemmId % xTilda;
auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);

Expand Down
87 changes: 53 additions & 34 deletions mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h"
#include "mlir/Dialect/MIOpen/utility/math.hpp"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
Expand Down Expand Up @@ -66,7 +67,7 @@ int Conv2dGenerator::getKernelCount() const {
} else if (config.operation == "conv2d") {
count = 1;
} else if (config.operation == "conv2d_bwd_data") {
count = 1;
count = getBwdDataKernelCount();
} else if (config.operation == "conv2d_bwd_weight") {
count = 1;
} else if (config.operation == "conv2d_dummy") {
Expand All @@ -75,6 +76,31 @@ int Conv2dGenerator::getKernelCount() const {
return count;
}

// Returns how many backward-data kernels have to be generated for the current
// config.
//
// There are yTilda * xTilda candidate gemm ids, where yTilda / xTilda are the
// stride divided by gcd(stride, dilation) in each direction. A candidate only
// counts when its gemmK extent (yDotSlice * xDotSlice) is positive.
int Conv2dGenerator::getBwdDataKernelCount() const {
  const auto filterH = config.filterHeight;
  const auto filterW = config.filterWidth;

  auto yTilda = config.strideHeight /
                math::gcd(config.strideHeight, config.dilationHeight);
  auto xTilda =
      config.strideWidth / math::gcd(config.strideWidth, config.dilationWidth);

  const auto numCandidates = yTilda * xTilda;
  int numKernels = 0;
  for (int id = 0; id < numCandidates; ++id) {
    // Each candidate GEMM has its own gemmK extent.
    const auto rowOffset = id / xTilda;
    const auto colOffset = id % xTilda;

    auto rowSlice = math::integer_divide_ceil(filterH - rowOffset, yTilda);
    auto colSlice = math::integer_divide_ceil(filterW - colOffset, xTilda);
    // A GEMM with a non-positive gemmK has nothing to run.
    if (rowSlice * colSlice > 0)
      ++numKernels;
  }

  return numKernels;
}
Type Conv2dGenerator::getDataType(OpBuilder &builder) const {
mlir::Type dataType;
if (config.dataTypeStr == "f32" || config.dataTypeStr == "fp32") {
Expand Down Expand Up @@ -168,7 +194,8 @@ Conv2dGenerator::parseConvDims(int64_t batchSize, int64_t groupSize,
int64_t inputWidth, int64_t outputChannel,
int64_t outputHeight, int64_t outputWidth,
int64_t filterHeight, int64_t filterWidth) {

config.filterHeight = filterHeight;
config.filterWidth = filterWidth;
static const std::string filterKeys = "kgcyx";
int64_t filterVals[] = {outputChannel / groupSize, groupSize,
inputChannel / groupSize, filterHeight, filterWidth};
Expand Down Expand Up @@ -227,7 +254,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
OpBuilder &builder,
int kernel_id) {
if (kernel_id == -1) {
kernel_id = config.kernelId;
kernel_id = std::max(config.kernelId, 0);
}

Type dataType = getDataType(builder);
Expand Down Expand Up @@ -284,6 +311,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
}

std::vector<NamedAttribute> attributes{
builder.getNamedAttr("gemm_id", builder.getI32IntegerAttr(kernel_id)),
builder.getNamedAttr("arch", builder.getStringAttr(config.arch)),
builder.getNamedAttr("num_cu", builder.getI32IntegerAttr(config.num_cu)),

Expand Down Expand Up @@ -323,43 +351,34 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
attributes.push_back(
builder.getNamedAttr("xdlopsV2", builder.getBoolAttr(true)));

if (kernel_id > 0) {
if (config.operation == "conv2d") {
auto convOp = builder.create<miopen::Conv2DOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_data") {
auto convOp = builder.create<miopen::Conv2DBwdDataOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_weight") {
auto convOp = builder.create<miopen::Conv2DBwdWeightOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_back(convOp);
} else if (config.operation == "conv2d_dummy") {
auto convOp = builder.create<miopen::Conv2DDummyOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else {
if (config.operation == "conv2d") {
auto convOp = builder.create<miopen::Conv2DOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_data") {
auto convOp = builder.create<miopen::Conv2DBwdDataOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_weight") {
auto convOp = builder.create<miopen::Conv2DBwdWeightOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_back(convOp);
} else if (config.operation == "conv2d_dummy") {
auto convOp = builder.create<miopen::Conv2DDummyOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
}
}

auto returnOp =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_ckyx_cnhw_knhw(%filter : memref<1x8x128x3x3xf32>, %
output_layout = ["go", "ko", "no", "ho", "wo"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0, 0, 0]
padding = [0, 0, 0, 0],
gemm_id = 0
} : memref<1x8x128x3x3xf32>, memref<1x8x128x32x32xf32>, memref<1x128x128x30x30xf32>
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_cyxk_chwn_khwn(%filter : memref<1x8x3x3x128xf32>, %
output_layout = ["go", "ko", "ho", "wo", "no"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0, 0, 0]
padding = [0, 0, 0, 0],
gemm_id = 0
} : memref<1x8x3x3x128xf32>, memref<1x8x32x32x128xf32>, memref<1x128x30x30x128xf32>
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_cyxk_cnhw_knhw(%filter : memref<1x8x3x3x128xf32>, %
output_layout = ["go", "ko", "no", "ho", "wo"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0, 0, 0]
padding = [0, 0, 0, 0],
gemm_id = 0
} : memref<1x8x3x3x128xf32>, memref<1x8x128x32x32xf32>, memref<1x128x128x30x30xf32>
return
}
Expand Down
Loading

0 comments on commit 22c4eb7

Please sign in to comment.