Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward data #246

Merged
merged 23 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class Conv2dGenerator {
SmallVector<int64_t, 5> filterDimension;
SmallVector<int64_t, 5> inputDimension;
SmallVector<int64_t, 5> outputDimension;

int filterHeight;
int filterWidth;
};

Conv2dGenerator(const std::string &arch = "", int num_cu = 0,
Expand Down Expand Up @@ -98,6 +101,7 @@ class Conv2dGenerator {
});
return permutation;
}
int getBwdDataKernelCount() const;

// Generator config
Config config;
Expand Down
30 changes: 26 additions & 4 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,7 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {

LogicalResult backwardData(T op, PatternRewriter &b) const {
auto loc = op.getLoc();

auto gemmIdAttr = op->template getAttrOfType<IntegerAttr>("gemm_id");
auto archAttr = op->template getAttrOfType<StringAttr>("arch");
auto numCuAttr = op->template getAttrOfType<IntegerAttr>("num_cu");

Expand Down Expand Up @@ -2007,8 +2007,8 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {
auto gcdStrideDilationH = math::gcd(strideH, dilationH);
auto gcdStrideDilationW = math::gcd(strideW, dilationW);

auto yTilda = dilationH / gcdStrideDilationH;
auto xTilda = dilationW / gcdStrideDilationW;
auto yTilda = strideH / gcdStrideDilationH;
auto xTilda = strideW / gcdStrideDilationW;
whchung marked this conversation as resolved.
Show resolved Hide resolved

auto yDot = math::integer_divide_ceil(y, yTilda);
auto xDot = math::integer_divide_ceil(x, xTilda);
Expand All @@ -2029,7 +2029,28 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {
auto hTildaSlice = iHTildaRight - iHTildaLeft;
auto wTildaSlice = iWTildaRight - iWTildaLeft;

auto gemmId = 0;
auto getGemmId = [&](int kernelId) {
// kernelId 0 must be gemmId 0
if (kernelId <= 0)
return 0;

llvm::SmallVector<int> gemmIds;
for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) {
// gemm_k size is different for each GEMM
const auto iYTilda = gemmId / xTilda;
const auto iXTilda = gemmId % xTilda;

auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
// gemmK must > 0, otherwise not need to run
if (yDotSlice * xDotSlice > 0) {
gemmIds.push_back(gemmId);
}
}
assert(gemmIds.size() > kernelId);
return gemmIds[kernelId];
};
auto gemmId = getGemmId(gemmIdAttr.getInt());
auto iYTilda = gemmId / xTilda;
auto iXTilda = gemmId % xTilda;
auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
Expand Down Expand Up @@ -3279,6 +3300,7 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {

// Set attributes for gridwise_gemm op.
llvm::SmallVector<NamedAttribute, 8> gridwiseGemmAttrs{
b.getNamedAttr("gemm_id", gemmIdAttr),
b.getNamedAttr("arch", archAttr),
b.getNamedAttr("num_cu", numCuAttr),
b.getNamedAttr("filter_layout", filterLayoutAttr),
Expand Down
15 changes: 11 additions & 4 deletions mlir/include/mlir/Dialect/MIOpen/Tuning/ConvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ struct ConvolutionContext : SQLiteSerializable<ConvolutionContext> {
llvm::SmallVector<int64_t, 0> strideVal;
llvm::SmallVector<int64_t, 0> dilationVal;
llvm::SmallVector<int64_t, 0> paddingVal;
int gemmId;

ConvolutionContext(const llvm::SmallString<8> &architecture, int numCu,
miopen::ConvOpType op,
llvm::StringMap<std::pair<size_t, int64_t>> dim,
llvm::SmallVector<int64_t, 0> stride,
llvm::SmallVector<int64_t, 0> dilation,
llvm::SmallVector<int64_t, 0> padding)
llvm::SmallVector<int64_t, 0> padding, int gemmid)
: arch(architecture), num_cu(numCu), opType(op), dimIndexVal(dim),
strideVal(stride), dilationVal(dilation), paddingVal(padding) {}
strideVal(stride), dilationVal(dilation), paddingVal(padding),
gemmId(gemmid) {}

llvm::StringMap<std::pair<size_t, int64_t>> getDimIndexVal() const {
return dimIndexVal;
Expand Down Expand Up @@ -146,6 +148,11 @@ template <typename T> static ConvolutionContext populateConvContext(T &op) {

auto archVal = op->template getAttrOfType<StringAttr>("arch").getValue();
int numCuVal = op->template getAttrOfType<IntegerAttr>("num_cu").getInt();
auto gemmIdAttr = op->template getAttrOfType<IntegerAttr>("gemm_id");
int gemmId = 0;
if (gemmIdAttr) {
gemmId = gemmIdAttr.getInt();
}

llvm::StringMap<std::pair<size_t, int64_t>> dimIndexVal;

Expand Down Expand Up @@ -176,8 +183,8 @@ template <typename T> static ConvolutionContext populateConvContext(T &op) {
llvm::SmallVector<int64_t, 0> paddingVal;
populateSeqVal(paddingAttr, paddingVal);

return {archVal, numCuVal, opType, dimIndexVal,
strideVal, dilationVal, paddingVal};
return {archVal, numCuVal, opType, dimIndexVal,
strideVal, dilationVal, paddingVal, gemmId};
}

} // namespace mlir
Expand Down
61 changes: 46 additions & 15 deletions mlir/include/mlir/Dialect/MIOpen/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,8 @@ class PopulateParamsBase {
input1GemmKVectorizable = true;
}
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
// When K is the fastest changing dimension(3),
// gemmK dimension is vectorizable, gemmM is not, and vice versa.
// Vectorization width depending on length of K.
if (dimIndexVal["k"].first == 4) {
input1GemmKVectorizable = true;
} else {
input1GemmKVectorizable = false;
}
// always load gemmM first
input1GemmKVectorizable = false;
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
// When K is the fastest changing dimension,
// gemmM dimension is vectorizable, gemmK is not, and vice versa.
Expand Down Expand Up @@ -200,6 +194,23 @@ class PopulateParamsBase {
}
}

static void obtainBwdDataFilterVecLen(ConvolutionContext &ctx,
whchung marked this conversation as resolved.
Show resolved Hide resolved
int64_t &vecLen) {
auto dimIndexVal = ctx.dimIndexVal;
// Vectorization length logic is the same for forward and bwd_data
if (dimIndexVal["c"].first == 4) {
vecLen = dimIndexVal["c"].second;
} else if (dimIndexVal["c"].first == 2) {
// C's position is at 2, vectorization legnth depend last two dimension
if (dimIndexVal["y"].second == 1 && dimIndexVal["x"].second == 1) {
vecLen = dimIndexVal["c"].second;
} else {
vecLen = 1;
}
} else {
vecLen = 1;
}
}
static void obtainInputVecLen(ConvolutionContext &ctx, int64_t &vecLen) {
auto dimIndexVal = ctx.dimIndexVal;
if (dimIndexVal["ni"].first == 4) {
Expand All @@ -216,6 +227,26 @@ class PopulateParamsBase {
vecLen = 1;
}
}
static void obtainBwdDataOutputVecLen(ConvolutionContext &ctx,
int64_t &vecLen) {
auto dimIndexVal = ctx.dimIndexVal;
if (dimIndexVal["ko"].first == 4) {
vecLen = dimIndexVal["ko"].second;
} else if (dimIndexVal["no"].first == 4) {
vecLen = dimIndexVal["no"].second;
} else if (dimIndexVal["no"].first == 0) {
if (dimIndexVal["ho"].first == 3 && dimIndexVal["wo"].first == 4) {
if (dimIndexVal["y"].second == 1 && dimIndexVal["x"].second == 1)
asleepzzz marked this conversation as resolved.
Show resolved Hide resolved
vecLen = dimIndexVal["ho"].second * dimIndexVal["wo"].second;
else
vecLen = 1;
} else
vecLen = 1;
} else {
vecLen = 1;
}
}

static void obtainOutputVecLen(ConvolutionContext &ctx, int64_t &vecLen) {
auto dimIndexVal = ctx.dimIndexVal;
if (dimIndexVal["ko"].first == 4) {
Expand Down Expand Up @@ -250,7 +281,7 @@ class PopulateParamsBase {
if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
obtainFilterVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainFilterVecLen(ctx, vecLen);
obtainBwdDataFilterVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainOutputVecLen(ctx, vecLen);
}
Expand All @@ -261,7 +292,7 @@ class PopulateParamsBase {
if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
obtainInputVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainOutputVecLen(ctx, vecLen);
obtainBwdDataOutputVecLen(ctx, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainInputVecLen(ctx, vecLen);
}
Expand Down Expand Up @@ -385,8 +416,8 @@ class PopulateParamsBase {
auto gcdStrideDilationH = math::gcd(strideH, dilationH);
auto gcdStrideDilationW = math::gcd(strideW, dilationW);

auto yTilda = dilationH / gcdStrideDilationH;
auto xTilda = dilationW / gcdStrideDilationW;
auto yTilda = strideH / gcdStrideDilationH;
auto xTilda = strideW / gcdStrideDilationW;
whchung marked this conversation as resolved.
Show resolved Hide resolved

auto hTilda =
ho + math::integer_divide_ceil(dilationH * (y - 1), strideH);
Expand All @@ -406,9 +437,9 @@ class PopulateParamsBase {
auto hTildaSlice = iHTildaRight - iHTildaLeft;
auto wTildaSlice = iWTildaRight - iWTildaLeft;

auto gemm_id = 0;
auto iYTilda = gemm_id / xTilda;
auto iXTilda = gemm_id % xTilda;
auto gemmId = ctx.gemmId;
auto iYTilda = gemmId / xTilda;
auto iXTilda = gemmId % xTilda;
auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);

Expand Down
87 changes: 53 additions & 34 deletions mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h"
#include "mlir/Dialect/MIOpen/utility/math.hpp"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
Expand Down Expand Up @@ -66,7 +67,7 @@ int Conv2dGenerator::getKernelCount() const {
} else if (config.operation == "conv2d") {
count = 1;
} else if (config.operation == "conv2d_bwd_data") {
count = 1;
count = getBwdDataKernelCount();
} else if (config.operation == "conv2d_bwd_weight") {
count = 1;
} else if (config.operation == "conv2d_dummy") {
Expand All @@ -75,6 +76,31 @@ int Conv2dGenerator::getKernelCount() const {
return count;
}

int Conv2dGenerator::getBwdDataKernelCount() const {
auto gcdStrideDilationH =
math::gcd(config.strideHeight, config.dilationHeight);
auto gcdStrideDilationW = math::gcd(config.strideWidth, config.dilationWidth);

auto yTilda = config.strideHeight / gcdStrideDilationH;
auto xTilda = config.strideWidth / gcdStrideDilationW;

auto y = config.filterHeight;
auto x = config.filterWidth;
int count = 0;
for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) {
jerryyin marked this conversation as resolved.
Show resolved Hide resolved
// gemm_k size is different for each GEMM
const auto iYTilda = gemmId / xTilda;
const auto iXTilda = gemmId % xTilda;

auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
// gemmK must > 0, otherwise not need to run
if (yDotSlice * xDotSlice > 0)
count++;
}

return count;
}
Type Conv2dGenerator::getDataType(OpBuilder &builder) const {
mlir::Type dataType;
if (config.dataTypeStr == "f32" || config.dataTypeStr == "fp32") {
Expand Down Expand Up @@ -168,7 +194,8 @@ Conv2dGenerator::parseConvDims(int64_t batchSize, int64_t groupSize,
int64_t inputWidth, int64_t outputChannel,
int64_t outputHeight, int64_t outputWidth,
int64_t filterHeight, int64_t filterWidth) {

config.filterHeight = filterHeight;
config.filterWidth = filterWidth;
static const std::string filterKeys = "kgcyx";
int64_t filterVals[] = {outputChannel / groupSize, groupSize,
inputChannel / groupSize, filterHeight, filterWidth};
Expand Down Expand Up @@ -227,7 +254,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
OpBuilder &builder,
int kernel_id) {
if (kernel_id == -1) {
kernel_id = config.kernelId;
kernel_id = std::max(config.kernelId, 0);
}

Type dataType = getDataType(builder);
Expand Down Expand Up @@ -284,6 +311,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
}

std::vector<NamedAttribute> attributes{
builder.getNamedAttr("gemm_id", builder.getI32IntegerAttr(kernel_id)),
builder.getNamedAttr("arch", builder.getStringAttr(config.arch)),
builder.getNamedAttr("num_cu", builder.getI32IntegerAttr(config.num_cu)),

Expand Down Expand Up @@ -323,43 +351,34 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
attributes.push_back(
builder.getNamedAttr("xdlopsV2", builder.getBoolAttr(true)));

if (kernel_id > 0) {
if (config.operation == "conv2d") {
auto convOp = builder.create<miopen::Conv2DOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_data") {
auto convOp = builder.create<miopen::Conv2DBwdDataOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_weight") {
auto convOp = builder.create<miopen::Conv2DBwdWeightOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_back(convOp);
} else if (config.operation == "conv2d_dummy") {
whchung marked this conversation as resolved.
Show resolved Hide resolved
auto convOp = builder.create<miopen::Conv2DDummyOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else {
if (config.operation == "conv2d") {
auto convOp = builder.create<miopen::Conv2DOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_data") {
auto convOp = builder.create<miopen::Conv2DBwdDataOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
} else if (config.operation == "conv2d_bwd_weight") {
auto convOp = builder.create<miopen::Conv2DBwdWeightOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_back(convOp);
} else if (config.operation == "conv2d_dummy") {
auto convOp = builder.create<miopen::Conv2DDummyOp>(
builder.getUnknownLoc(), ArrayRef<mlir::Type>{},
ValueRange{func.getArgument(0), func.getArgument(1),
func.getArgument(2)},
attributes);
block->push_front(convOp);
}
}

auto returnOp =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_ckyx_cnhw_knhw(%filter : memref<1x8x128x3x3xf32>, %
output_layout = ["go", "ko", "no", "ho", "wo"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0, 0, 0]
padding = [0, 0, 0, 0],
gemm_id = 0
} : memref<1x8x128x3x3xf32>, memref<1x8x128x32x32xf32>, memref<1x128x128x30x30xf32>
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_cyxk_chwn_khwn(%filter : memref<1x8x3x3x128xf32>, %
output_layout = ["go", "ko", "ho", "wo", "no"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0, 0, 0]
padding = [0, 0, 0, 0],
gemm_id = 0
} : memref<1x8x3x3x128xf32>, memref<1x8x32x32x128xf32>, memref<1x128x30x30x128xf32>
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_cyxk_cnhw_knhw(%filter : memref<1x8x3x3x128xf32>, %
output_layout = ["go", "ko", "no", "ho", "wo"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0, 0, 0]
padding = [0, 0, 0, 0],
gemm_id = 0
} : memref<1x8x3x3x128xf32>, memref<1x8x128x32x32xf32>, memref<1x128x128x30x30xf32>
return
}
Expand Down
Loading