From 5a30f4f668b4dab320e62156eccc9ad632b9c4f6 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Wed, 9 Jun 2021 18:27:08 +0000 Subject: [PATCH 01/11] enable fwd, wrw padding kernel --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 299 +++++++++++------- .../MIOpen/Transforms/AffineTransforms.cpp | 6 +- .../MIOpen/Tuning/GridwiseGemmParams.cpp | 14 +- 3 files changed, 199 insertions(+), 120 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index f441b273da28..3e9ce561e36e 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -1537,11 +1537,25 @@ struct Conv2DRewritePattern : public OpRewritePattern { SmallString<8> gemmMPad_name("gemmMPad"); SmallString<8> gemmNPad_name("gemmNPad"); + int64_t nonGemmMSize = transformedFilterShape[1]; + int64_t gemmMSize = transformedFilterShape[2]; // filter pad start // filter : K & CRS , if CRS is under 64 or 32 // we pad CRS to 32 or 64, then mlir can do gemm // we add more one transform to do pad - if (convOpType == miopen::ConvOpType::Conv2DOpType && gemmKExtra > 0) { + bool filterCheckPadGemmM = false; + bool filterCheckPadGemmK = false; + bool filterCheckPadGemmN = false; + filterCheckPadGemmM = + (convOpType == miopen::ConvOpType::Conv2DOpType && gemmMExtra > 0) || + (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && + gemmMExtra > 0); + filterCheckPadGemmK = + (convOpType == miopen::ConvOpType::Conv2DOpType && gemmKExtra > 0); + filterCheckPadGemmN = + (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && + gemmNExtra > 0); + if (filterCheckPadGemmM || filterCheckPadGemmK || filterCheckPadGemmN) { StringAttr gemmDim0TargetName = b.getStringAttr(arg0TargetLayoutName0); StringAttr gemmDim1TargetName; StringAttr gemmDim2TargetName; @@ -1567,9 +1581,6 @@ struct Conv2DRewritePattern : public OpRewritePattern { paddingFilterShape.push_back(transformedFilterShape[1]); paddingFilterShape.push_back(transformedFilterShape[2]); - StringAttr gemmKDim; - IntegerAttr gemmKDimName; - llvm::SmallVector sourceGemmDim0Attr{ b.getNamedAttr("transformation", b.getStringAttr("PassThrough")), b.getNamedAttr("lower_layer_dimensions", b.getArrayAttr({GemmDim0})), @@ -1594,13 +1605,14 @@ struct Conv2DRewritePattern : public OpRewritePattern { b.getNamedAttr("upper_layer_dimensions", b.getArrayAttr({GemmDim2}))}; // gemmdim0 is G, only pad gemmdim1 and gemmdim2 - if (gemmKExtra > 0) { + if (filterCheckPadGemmK) { if (arg0TargetLayoutName1 == "gemmK") { isFilterPad = true; isGemmDim1Pad = true; gemmDim1TargetName = b.getStringAttr(gemmKPad_name); - paddingFilterShape[1] = paddingFilterShape[1] + gemmKExtra; + // fwd + paddingFilterShape[1] = nonGemmMSize + gemmKExtra; sourceGemmDim1Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim1Attr.push_back( @@ -1612,51 +1624,79 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim1Attr.push_back( b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); - } else if (arg0TargetLayoutName2 == "gemmK") { + } + // filter of forward, gemmK=c*y*x + filterOobCheckDims.insert(nameToDims["c"]); + filterOobCheckDims.insert(nameToDims["y"]); + filterOobCheckDims.insert(nameToDims["x"]); + } + + if (filterCheckPadGemmM) { + if (arg0TargetLayoutName1 == "gemmM") { + // wrw isFilterPad = true; - isGemmDim2Pad = true; - gemmDim2TargetName = b.getStringAttr(gemmKPad_name); + isGemmDim1Pad = true; + gemmDim1TargetName = b.getStringAttr(gemmMPad_name); + // even dim1 name is gemmM ,the size of dim 2 is gemmM + paddingFilterShape[2] = gemmMSize + gemmMExtra; - paddingFilterShape[2] = paddingFilterShape[2] + gemmKExtra; + sourceGemmDim1Attr.push_back( + b.getNamedAttr("transformation", b.getStringAttr("Pad"))); + sourceGemmDim1Attr.push_back( + b.getNamedAttr("parameters", b.getArrayAttr({ + b.getI32IntegerAttr(0), + b.getI32IntegerAttr(gemmMExtra), + }))); + + targetGemmDim1Attr.push_back(b.getNamedAttr( + "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); + } else if (arg0TargetLayoutName2 == "gemmM") { + // fwd + isFilterPad = true; + isGemmDim2Pad = true; + gemmDim2TargetName = b.getStringAttr(gemmMPad_name); + paddingFilterShape[2] = gemmMSize + gemmMExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back( b.getNamedAttr("parameters", b.getArrayAttr({ b.getI32IntegerAttr(0), - b.getI32IntegerAttr(gemmKExtra), + b.getI32IntegerAttr(gemmMExtra), }))); - targetGemmDim2Attr.push_back( - b.getNamedAttr("upper_layer_names", - b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); + targetGemmDim2Attr.push_back(b.getNamedAttr( + "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); } - // filter of forward, gemmK=c*y*x - filterOobCheckDims.insert(nameToDims["c"]); - filterOobCheckDims.insert(nameToDims["y"]); - filterOobCheckDims.insert(nameToDims["x"]); - } - if (gemmMExtra > 0) { - if (arg0TargetLayoutName1 == "gemmM") { - isFilterPad = false; - isGemmDim1Pad = false; - paddingFilterShape[1] = paddingFilterShape[1] + gemmMExtra; - } else if (arg0TargetLayoutName2 == "gemmM") { - isFilterPad = false; - isGemmDim2Pad = false; - paddingFilterShape[2] = paddingFilterShape[2] + gemmMExtra; - } + filterOobCheckDims.insert(nameToDims["k"]); } - if (gemmNExtra > 0) { - if (arg0TargetLayoutName1 == "gemmN") { - isFilterPad = false; - isGemmDim1Pad = false; - paddingFilterShape[1] = paddingFilterShape[1] + gemmNExtra; - } else if (arg0TargetLayoutName2 == "gemmN") { - isFilterPad = false; - isGemmDim2Pad = false; - paddingFilterShape[2] = paddingFilterShape[2] + gemmNExtra; + if (filterCheckPadGemmN) { + if (arg0TargetLayoutName2 == "gemmN") { + // wrw + isFilterPad = true; + isGemmDim2Pad = true; + gemmDim2TargetName = b.getStringAttr(gemmNPad_name); + paddingFilterShape[1] = nonGemmMSize + gemmNExtra; + sourceGemmDim2Attr.push_back( + b.getNamedAttr("transformation", b.getStringAttr("Pad"))); + sourceGemmDim2Attr.push_back( + b.getNamedAttr("parameters", b.getArrayAttr({ + b.getI32IntegerAttr(0), + b.getI32IntegerAttr(gemmNExtra), + }))); + + targetGemmDim2Attr.push_back(b.getNamedAttr( + "names", b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); + } + // FIXME: if we set every dim in merge transformation to store oob, + // can't pass verification, but only set top dim , it's ok + if (filterYDim == 2) { + // kyxc + filterOobCheckDims.insert(nameToDims["y"]); + } else { + // kcyx + filterOobCheckDims.insert(nameToDims["c"]); } } @@ -2325,7 +2365,19 @@ struct Conv2DRewritePattern : public OpRewritePattern { // input : NHW & CRS , if CRS is under 64 or 32 // we pad CRS to 32 or 64, then mlir can do gemm // we add more one transform to do pad - if (convOpType == miopen::ConvOpType::Conv2DOpType && gemmKExtra > 0) { + + // input do not GemmM when forward and backward weights + bool inputCheckPadGemmK = false; + bool inputCheckPadGemmN = false; + inputCheckPadGemmK = + (convOpType == miopen::ConvOpType::Conv2DOpType && gemmKExtra > 0) || + (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && + gemmKExtra > 0); + inputCheckPadGemmN = + (convOpType == miopen::ConvOpType::Conv2DOpType && gemmNExtra > 0) || + (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && + gemmNExtra > 0); + if (inputCheckPadGemmK || inputCheckPadGemmN) { llvm::SmallVector paddingInputShape; llvm::SmallVector paddingInputAttrs; @@ -2377,10 +2429,11 @@ struct Conv2DRewritePattern : public OpRewritePattern { llvm::SmallVector targetGemmDim2Attr{ b.getNamedAttr("upper_layer_dimensions", b.getArrayAttr({GemmDim2}))}; - if (gemmKExtra > 0) { + if (inputCheckPadGemmK) { if (arg1TargetLayoutName1 == "gemmK") { isInputPad = true; isGemmDim1Pad = true; + // fwd is cyx ,wrw is nhw gemmDim1TargetName = b.getStringAttr(gemmKPad_name); paddingInputShape[1] = paddingInputShape[1] + gemmKExtra; @@ -2392,49 +2445,44 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim1Attr.push_back( b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); - } else if (arg1TargetLayoutName2 == "gemmK") { + + if (convOpType == miopen::ConvOpType::Conv2DOpType) { + inputOobCheckDims.insert(nameToDims["ci"]); + } else if (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType) { + inputOobCheckDims.insert(nameToDims["ni"]); + } + + inputOobCheckDims.insert(nameToDims["hi"]); + inputOobCheckDims.insert(nameToDims["wi"]); + } + } + + if (inputCheckPadGemmN) { + if (arg1TargetLayoutName2 == "gemmN") { isInputPad = true; isGemmDim2Pad = true; - gemmDim2TargetName = b.getStringAttr(gemmKPad_name); + gemmDim2TargetName = b.getStringAttr(gemmNPad_name); + paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; - paddingInputShape[2] = paddingInputShape[2] + gemmKExtra; + paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back(b.getNamedAttr( "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), - b.getI32IntegerAttr(gemmKExtra)}))); - + b.getI32IntegerAttr(gemmNExtra)}))); targetGemmDim2Attr.push_back( b.getNamedAttr("upper_layer_names", - b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); - } - // input of forward, gemmK = ci * y * x( y x from hi wi) - inputOobCheckDims.insert(nameToDims["ci"]); - inputOobCheckDims.insert(nameToDims["hi"]); - inputOobCheckDims.insert(nameToDims["wi"]); - } - - if (gemmMExtra > 0) { - if (arg1TargetLayoutName1 == "gemmM") { - isInputPad = false; - isGemmDim1Pad = false; - paddingInputShape[1] = paddingInputShape[1] + gemmMExtra; - } else if (arg1TargetLayoutName2 == "gemmM") { - isInputPad = false; - isGemmDim2Pad = false; - paddingInputShape[2] = paddingInputShape[2] + gemmMExtra; - } - } - - if (gemmNExtra > 0) { - if (arg1TargetLayoutName1 == "gemmN") { - isInputPad = false; - isGemmDim1Pad = false; - paddingInputShape[1] = paddingInputShape[1] + gemmNExtra; - } else if (arg1TargetLayoutName2 == "gemmN") { - isInputPad = false; - isGemmDim2Pad = false; - paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; + b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); + + if (convOpType == miopen::ConvOpType::Conv2DOpType) { + inputOobCheckDims.insert(nameToDims["ni"]); + inputOobCheckDims.insert(nameToDims["hi"]); + inputOobCheckDims.insert(nameToDims["wi"]); + } else if (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType) { + inputOobCheckDims.insert(nameToDims["ci"]); + inputOobCheckDims.insert(nameToDims["hi"]); + inputOobCheckDims.insert(nameToDims["wi"]); + } } } @@ -2680,8 +2728,20 @@ struct Conv2DRewritePattern : public OpRewritePattern { // If Nhw is under 32 or 64 ,we pad it to 32 or 64 // then mlir can do gemm // we just add more one transform to do it - if (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && - gemmKExtra > 0) { + + bool outputCheckPadGemmK = false; + bool outputCheckPadGemmM = false; + bool outputCheckPadGemmN = false; + outputCheckPadGemmK = + (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && + gemmKExtra > 0); + outputCheckPadGemmM = + (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType && + gemmMExtra > 0) || + (convOpType == miopen::ConvOpType::Conv2DOpType && gemmMExtra > 0); + outputCheckPadGemmN = + (convOpType == miopen::ConvOpType::Conv2DOpType && gemmNExtra > 0); + if (outputCheckPadGemmK || outputCheckPadGemmM || outputCheckPadGemmN) { StringAttr gemmDim0TargetName = b.getStringAttr(arg2TargetLayoutName0); StringAttr gemmDim1TargetName; StringAttr gemmDim2TargetName; @@ -2707,9 +2767,6 @@ struct Conv2DRewritePattern : public OpRewritePattern { paddingOutputShape.push_back(transformedOutputShape[1]); paddingOutputShape.push_back(transformedOutputShape[2]); - StringAttr gemmKDim; - IntegerAttr gemmKDimName; - llvm::SmallVector sourceGemmDim0Attr{ b.getNamedAttr("transformation", b.getStringAttr("PassThrough")), b.getNamedAttr("lower_layer_dimensions", b.getArrayAttr({GemmDim0})), @@ -2733,12 +2790,12 @@ struct Conv2DRewritePattern : public OpRewritePattern { llvm::SmallVector targetGemmDim2Attr{ b.getNamedAttr("upper_layer_dimensions", b.getArrayAttr({GemmDim2}))}; - if (gemmKExtra > 0) { + if (outputCheckPadGemmK) { if (arg2TargetLayoutName1 == "gemmK") { isOutputPad = true; isGemmDim1Pad = true; gemmDim1TargetName = b.getStringAttr(gemmKPad_name); - + // wrw dim 1 is nhw paddingOutputShape[1] = paddingOutputShape[1] + gemmKExtra; sourceGemmDim1Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); @@ -2749,49 +2806,65 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim1Attr.push_back( b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); - } else if (arg2TargetLayoutName2 == "gemmK") { - isOutputPad = true; - isGemmDim2Pad = true; - gemmDim2TargetName = b.getStringAttr(gemmKPad_name); - - paddingOutputShape[2] = paddingOutputShape[2] + gemmKExtra; - sourceGemmDim2Attr.push_back( - b.getNamedAttr("transformation", b.getStringAttr("Pad"))); - sourceGemmDim2Attr.push_back(b.getNamedAttr( - "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), - b.getI32IntegerAttr(gemmKExtra)}))); - - targetGemmDim2Attr.push_back( - b.getNamedAttr("upper_layer_names", - b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); + // wrw + outputOobCheckDims.insert(nameToDims["no"]); + outputOobCheckDims.insert(nameToDims["ho"]); + outputOobCheckDims.insert(nameToDims["wo"]); } - // output of forward, gemmK = no * ho * wo - outputOobCheckDims.insert(nameToDims["no"]); - outputOobCheckDims.insert(nameToDims["ho"]); - outputOobCheckDims.insert(nameToDims["wo"]); } - if (gemmMExtra > 0) { + if (outputCheckPadGemmM) { if (arg2TargetLayoutName1 == "gemmM") { - isOutputPad = false; - isGemmDim1Pad = false; + isOutputPad = true; + isGemmDim1Pad = true; + gemmDim1TargetName = b.getStringAttr(gemmMPad_name); + // fwd, k paddingOutputShape[1] = paddingOutputShape[1] + gemmMExtra; + sourceGemmDim1Attr.push_back( + b.getNamedAttr("transformation", b.getStringAttr("Pad"))); + sourceGemmDim1Attr.push_back(b.getNamedAttr( + "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), + b.getI32IntegerAttr(gemmMExtra)}))); + + targetGemmDim1Attr.push_back(b.getNamedAttr( + "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); + outputOobCheckDims.insert(nameToDims["ko"]); } else if (arg2TargetLayoutName2 == "gemmM") { - isOutputPad = false; - isGemmDim2Pad = false; + isOutputPad = true; + isGemmDim2Pad = true; + gemmDim2TargetName = b.getStringAttr(gemmMPad_name); + // wrw ,k paddingOutputShape[2] = paddingOutputShape[2] + gemmMExtra; + sourceGemmDim2Attr.push_back( + b.getNamedAttr("transformation", b.getStringAttr("Pad"))); + sourceGemmDim2Attr.push_back(b.getNamedAttr( + "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), + b.getI32IntegerAttr(gemmMExtra)}))); + + targetGemmDim2Attr.push_back(b.getNamedAttr( + "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); + outputOobCheckDims.insert(nameToDims["ko"]); } } - if (gemmNExtra > 0) { - if (arg2TargetLayoutName1 == "gemmN") { - isOutputPad = false; - isGemmDim1Pad = false; - paddingOutputShape[1] = paddingOutputShape[1] + gemmNExtra; - } else if (arg2TargetLayoutName2 == "gemmN") { - isOutputPad = false; - isGemmDim2Pad = false; + if (outputCheckPadGemmN) { + if (arg2TargetLayoutName2 == "gemmN") { + // fwd dim 2,nhw + isOutputPad = true; + isGemmDim2Pad = true; + gemmDim2TargetName = b.getStringAttr(gemmNPad_name); paddingOutputShape[2] = paddingOutputShape[2] + gemmNExtra; + sourceGemmDim2Attr.push_back( + b.getNamedAttr("transformation", b.getStringAttr("Pad"))); + sourceGemmDim2Attr.push_back(b.getNamedAttr( + "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), + b.getI32IntegerAttr(gemmNExtra)}))); + + targetGemmDim2Attr.push_back(b.getNamedAttr( + "names", b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); + // FIXME: to set dim in merge transormation to oob store, + // set only top dim or you will get zero values + outputOobCheckDims.insert(nameToDims["no"]); } } diff --git a/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp b/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp index 54078fcb3661..1fb7f1ce6677 100644 --- a/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp +++ b/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp @@ -55,6 +55,9 @@ AffineMap AffineTransforms::buildIndexAffineMap(miopen::TransformOp op) { } else if (transformAttr.getValue() == "Pad") { assert(srcDimAttr.size() == destDimAttr.size()); + // FIXME: wrw padding kernel, rank of memref is not the same with + // padding transform, so close it now until issue fixing + bool closeRightPadding = true; auto parameters = dimLayoutAttr.get("parameters").cast(); for (unsigned j = 0; j < srcDimAttr.size(); ++j) { // example of h and w pad parameters [0, 2, 3, 1] : @@ -72,7 +75,8 @@ AffineMap AffineTransforms::buildIndexAffineMap(miopen::TransformOp op) { auto destDim = destDimAttr.getValue()[j].cast().getInt(); auto expr = getAffineDimExpr(destDim, op.getContext()) + getAffineConstantExpr(-leftPad, op.getContext()); - if (leftPad == 0 && rightPad != 0) { + + if (leftPad == 0 && rightPad != 0 & !closeRightPadding) { // when leftPad == 0 , your original expr is just minus leftpad, but // leftpad is zero, affinemap do not have minus out of boundary // check depends on minus symbol , it will not do out of boundary diff --git a/mlir/lib/Dialect/MIOpen/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/MIOpen/Tuning/GridwiseGemmParams.cpp index 64b0aa10fd65..de23d499a498 100644 --- a/mlir/lib/Dialect/MIOpen/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/MIOpen/Tuning/GridwiseGemmParams.cpp @@ -177,8 +177,7 @@ LogicalResult PopulateParams::paramsFromCtx( << " PARAMETERS!\n"); InitParams paddingParam = getUniversalParameters(); - if ((gemmSize.gemmN % paddingParam.gemmNPerBlock == 0) && - (gemmSize.gemmM % paddingParam.gemmMPerBlock == 0)) { + if (ctx.opType != miopen::ConvOpType::Conv2DBwdDataOpType) { LLVM_DEBUG(llvm::dbgs() << "BUT PADDING KERNEL CAN EXECUTE IT\n"); for (auto ¶ms : initParameters) { @@ -194,7 +193,9 @@ LogicalResult PopulateParams::paramsFromCtx( break; } } else { - LLVM_DEBUG(llvm::dbgs() << "PADDING KERNEL only support gemmK now\n"); + LLVM_DEBUG( + llvm::dbgs() + << "PADDING KERNEL only support forward, backward weights now\n"); } } else { LLVM_DEBUG(llvm::dbgs() << "Successfully picked tuning params from backup" @@ -380,8 +381,7 @@ LogicalResult PopulateParamsXDL::paramsFromCtx( << " PARAMETERS!\n"); InitParams paddingParam = getUniversalParameters(); - if ((gemmSize.gemmN % paddingParam.gemmNPerBlock == 0) && - (gemmSize.gemmM % paddingParam.gemmMPerBlock == 0)) { + if (ctx.opType != miopen::ConvOpType::Conv2DBwdDataOpType) { LLVM_DEBUG(llvm::dbgs() << "BUT PADDING KERNEL CAN EXECUTE IT\n"); for (auto ¶ms : initParameters) { res = populatePaddingKernelDerived(ctx, params, gemmSize, @@ -395,7 +395,9 @@ LogicalResult PopulateParamsXDL::paramsFromCtx( break; } } else { - LLVM_DEBUG(llvm::dbgs() << "PADDING KERNEL only support gemmK now\n"); + LLVM_DEBUG( + llvm::dbgs() + << "PADDING KERNEL only support forward, backward weights now\n"); } } else { LLVM_DEBUG(llvm::dbgs() << "Successfully picked tuning params from backup" From 562502e346b9fc4c6adc8b18c2190a29db6028ea Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Wed, 9 Jun 2021 20:18:03 +0000 Subject: [PATCH 02/11] fix mlir --- ...lowering_filter_tensor_ckyx_cnhw_knhw.mlir | 2 +- ...gridwise_gemm_position_cyxk_chwn_khwn.mlir | 4 ++- .../lowering_input_tensor_cyxk_cnhw_knhw.mlir | 4 +++ .../lowering_memref_kcyx_nchw_nkhw.mlir | 12 ++++---- .../Dialect/MIOpen/lowering_top_level.mlir | 30 +++++++++++-------- mlir/test/mlir-miopen-driver/padding_map.mlir | 17 ----------- 6 files changed, 33 insertions(+), 36 deletions(-) delete mode 100644 mlir/test/mlir-miopen-driver/padding_map.mlir diff --git a/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir index ec227b46f073..71b8ed42d3fe 100644 --- a/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir @@ -74,4 +74,4 @@ func @miopen_conv2d_bwd_weight_ckyx_cnhw_knhw(%filter : memref<1x8x128x3x3xf32>, // CHECK: upper_layer_names = ["gemmM"] // CHECK: lower_layer_names = ["c", "y", "x"] // CHECK: upper_layer_names = ["gemmN"] -// CHECK-NEXT: miopen.transform(%arg1) +// CHECK-NEXT: miopen.transform diff --git a/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir b/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir index 178c6d23f5a9..687dd7aaf464 100644 --- a/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir @@ -71,7 +71,9 @@ func @miopen_conv2d_bwd_weight_cyxk_chwn_khwn(%filter : memref<1x8x3x3x128xf32>, // CHECK-NEXT: miopen.transform // CHECK-NEXT: miopen.transform // CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform // CHECK: gridwise_gemm_argument_position = 1 // CHECK-NEXT: miopen.transform +// CHECK-NEXT: miopen.transform // CHECK: gridwise_gemm_argument_position = 0 -// CHECK-NEXT: miopen.gridwise_gemm(%4, %3, %0) +// CHECK-NEXT: miopen.gridwise_gemm(%6, %5, %1) diff --git a/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir index 6434a3fb51be..ca04911d6cb2 100644 --- a/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir @@ -66,10 +66,14 @@ func @miopen_conv2d_bwd_weight_cyxk_cnhw_knhw(%filter : memref<1x8x3x3x128xf32>, } // CHECK-LABEL: func @miopen_conv2d_bwd_weight // CHECK-NEXT: miopen.transform(%arg0) +// CHECK-NEXT: miopen.transform +// CHECK: upper_layer_layout = ["gemmG", "gemmM", "gemmNPad"] // CHECK-NEXT: miopen.transform(%arg1) // CHECK: upper_layer_layout = ["gi", "ci", "ni", "hipad", "wipad"] // CHECK-NEXT: miopen.transform // CHECK: upper_layer_layout = ["gi", "ci", "ni", "y", "ho", "x", "wo"] // CHECK-NEXT: miopen.transform // CHECK: upper_layer_layout = ["gemmG", "gemmK", "gemmN"] +// CHECK-NEXT: miopen.transform +// CHECK: upper_layer_layout = ["gemmG", "gemmK", "gemmNPad"] // CHECK-NEXT: miopen.transform(%arg2) diff --git a/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir b/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir index 3e75201b0547..f932b23b622a 100644 --- a/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir @@ -66,9 +66,11 @@ func @miopen_conv2d_bwd_weight_kcyx_nchw_nkhw(%filter : memref<1x128x8x3x3xf32>, return } // CHECK-LABEL: func @miopen_conv2d_bwd_weight -// CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"g", "k", "c", "y", "x".*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"gemmG", "gemmM", "gemmNPad".*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"ni", "gi", "ci", "hipad", "wipad".*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"ni", "gi", "ci", "y", "ho", "x", "wo".*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"gemmG", "gemmK", "gemmN".*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"gemmG", "gemmK", "gemmNPad".*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*"gemmG", "gemmK", "gemmM".*}.*memref.*memref}} // CHECK-NEXT: {{miopen.gridwise_gemm.*{.*}.*memref.*memref.*memref}} diff --git a/mlir/test/Dialect/MIOpen/lowering_top_level.mlir b/mlir/test/Dialect/MIOpen/lowering_top_level.mlir index 9ea979bcc037..da6870ca619e 100644 --- a/mlir/test/Dialect/MIOpen/lowering_top_level.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_top_level.mlir @@ -107,7 +107,7 @@ func @miopen_conv2d_bwd_data_f16(%filter : memref<1x128x8x3x3xf16>, %input : mem // CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmN"\].*}} // CHECK-NEXT: miopen.gridwise_gemm -func @miopen_conv2d_bwd_weight(%filter : memref<1x128x8x3x3xf32>, %input : memref<128x1x8x32x32xf32>, %output : memref<128x1x128x30x30xf32>) { +func @miopen_conv2d_bwd_weight(%filter : memref<1x20x8x3x3xf32>, %input : memref<7x1x8x32x32xf32>, %output : memref<7x1x20x30x30xf32>) { miopen.conv2d_bwd_weight(%filter, %input, %output) { arch = "gfx906", num_cu = 64, @@ -117,19 +117,22 @@ func @miopen_conv2d_bwd_weight(%filter : memref<1x128x8x3x3xf32>, %input : memre dilations = [1, 1], strides = [1, 1], padding = [0, 0, 0 ,0] - } : memref<1x128x8x3x3xf32>, memref<128x1x8x32x32xf32>, memref<128x1x128x30x30xf32> + } : memref<1x20x8x3x3xf32>, memref<7x1x8x32x32xf32>, memref<7x1x20x30x30xf32> return } // CHECK-LABEL: func {{@miopen_conv2d_bwd_weight.*%arg0.*%arg1.*%arg2}} -// CHECK-NOT: miopen.conv2d_bwd_data -// CHECK-NEXT: miopen.transform(%arg0) +// CHECK-NOT: miopen.conv2d_bwd_weight +// CHECK-NEXT: {{miopen.transform\(%arg0\).* upper_layer_layout = \["gemmG", "gemmM", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmMPad", "gemmNPad"\].*}} // CHECK-NEXT: miopen.transform(%arg1) // CHECK-NEXT: miopen.transform -// CHECK-NEXT: miopen.transform -// CHECK-NEXT: miopen.transform(%arg2) +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmKPad", "gemmNPad"\].*}} +// CHECK-NEXT: {{miopen.transform\(%arg2\).* upper_layer_layout = \["gemmG", "gemmK", "gemmM"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmKPad", "gemmMPad"\].*}} // CHECK-NEXT: miopen.gridwise_gemm -func @miopen_conv2d_bwd_weight_f16(%filter : memref<1x128x8x3x3xf16>, %input : memref<128x1x8x32x32xf16>, %output : memref<128x1x128x30x30xf16>) { +func @miopen_conv2d_bwd_weight_f16(%filter : memref<1x20x8x3x3xf16>, %input : memref<7x1x8x32x32xf16>, %output : memref<7x1x20x30x30xf16>) { miopen.conv2d_bwd_weight(%filter, %input, %output) { arch = "gfx906", num_cu = 64, @@ -139,14 +142,17 @@ func @miopen_conv2d_bwd_weight_f16(%filter : memref<1x128x8x3x3xf16>, %input : m dilations = [1, 1], strides = [1, 1], padding = [0, 0, 0 ,0] - } : memref<1x128x8x3x3xf16>, memref<128x1x8x32x32xf16>, memref<128x1x128x30x30xf16> + } : memref<1x20x8x3x3xf16>, memref<7x1x8x32x32xf16>, memref<7x1x20x30x30xf16> return } // CHECK-LABEL: func {{@miopen_conv2d_bwd_weight.*%arg0.*%arg1.*%arg2}} -// CHECK-NOT: miopen.conv2d_bwd_data -// CHECK-NEXT: miopen.transform(%arg0) +// CHECK-NOT: miopen.conv2d_bwd_weight_f16 +// CHECK-NEXT: {{miopen.transform\(%arg0\).* upper_layer_layout = \["gemmG", "gemmM", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmMPad", "gemmNPad"\].*}} // CHECK-NEXT: miopen.transform(%arg1) // CHECK-NEXT: miopen.transform -// CHECK-NEXT: miopen.transform -// CHECK-NEXT: miopen.transform(%arg2) +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmKPad", "gemmNPad"\].*}} +// CHECK-NEXT: {{miopen.transform\(%arg2\).* upper_layer_layout = \["gemmG", "gemmK", "gemmM"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmKPad", "gemmMPad"\].*}} // CHECK-NEXT: miopen.gridwise_gemm diff --git a/mlir/test/mlir-miopen-driver/padding_map.mlir b/mlir/test/mlir-miopen-driver/padding_map.mlir deleted file mode 100644 index a3ac56b5b563..000000000000 --- a/mlir/test/mlir-miopen-driver/padding_map.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// This tests checks the affinemap component: -// RUN: mlir-opt -miopen-lowering -miopen-affine-transform %s | FileCheck %s - - module { - func @pad_parameters_1101(%arg0: memref<256x1x128x28x28xf32>) attributes {kernel} { - %1 = miopen.transform(%arg0) {extraPad = false, gemmK_extra = 0 : i32, gemmN_extra = 0 : i32, layout = [{upper_layer_dimensions = [1 : i32], upper_layer_names = ["gi"], lower_layer_dimensions = [1 : i32], lower_layer_names = ["gi"], transformation = "PassThrough"}, {upper_layer_dimensions = [0 : i32], upper_layer_names = ["ni"], lower_layer_dimensions = [0 : i32], lower_layer_names = ["ni"], transformation = "PassThrough"}, {upper_layer_dimensions = [2 : i32], upper_layer_names = ["ci"], lower_layer_dimensions = [2 : i32], lower_layer_names = ["ci"], transformation = "PassThrough"}, {upper_layer_dimensions = [3 : i32, 4 : i32], upper_layer_names = ["hipad", "wipad"], parameters = [1 : i32, 1 : i32, 0 : i32, 1 : i32], lower_layer_dimensions = [3 : i32, 4 : i32], lower_layer_names = ["hi", "wi"], transformation = "Pad"}], upper_layer_layout = ["ni", "gi", "ci", "hipad", "wipad"], lower_layer_layout = ["ni", "gi", "ci", "hi", "wi"]} : memref<256x1x128x28x28xf32> to memref<256x1x128x28x28xf32> - return - } - -// CHECK: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3 - 1, (d4 + 2) ceildiv 29 + d4 - 1)> - - func @pad_parameters_00(%arg0: memref<256x1x67x28x28xf32>) attributes {kernel} { - %1 = miopen.transform(%arg0) {extraPad = true, gemmK_extra = 61 : i32, gemmN_extra = 0 : i32, layout = [{upper_layer_dimensions = [1 : i32], upper_layer_names = ["gi"], lower_layer_dimensions = [1 : i32], lower_layer_names = ["gi"], transformation = "PassThrough"}, {upper_layer_dimensions = [0 : i32], upper_layer_names = ["ni"], lower_layer_dimensions = [0 : i32], lower_layer_names = ["ni"], transformation = "PassThrough"}, {upper_layer_dimensions = [2 : i32], upper_layer_names = ["ci"], parameters = [0 : i32, 61 : i32],lower_layer_dimensions = [2 : i32], lower_layer_names = ["ci"], transformation = "Pad"}, {upper_layer_dimensions = [3 : i32, 4 : i32], upper_layer_names = ["hipad", "wipad"], parameters = [0 : i32, 0 : i32, 0 : i32, 0 : i32], lower_layer_dimensions = [3 : i32, 4 : i32], lower_layer_names = ["hi", "wi"], transformation = "Pad"}], upper_layer_layout = ["ni", "gi", "ci", "hipad", "wipad"], lower_layer_layout = ["ni", "gi", "ci", "hi", "wi"]} : memref<256x1x67x28x28xf32> to memref<256x1x128x28x28xf32> - return - } -// CHECK: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, (d2 + 2) ceildiv 68 + d2 - 1, d3, d4)> - } From 8ebb78794d9a6c28f19378dde7e93de4f179494d Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Wed, 9 Jun 2021 20:53:38 +0000 Subject: [PATCH 03/11] add unit test --- .../auto_e2e/padding_kernel_all_fwd.mlir | 21 +++++++++++++++++++ .../auto_e2e/padding_kernel_all_wrw.mlir | 21 +++++++++++++++++++ .../auto_e2e/padding_kernel_gemmN.mlir | 21 +++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_fwd.mlir create mode 100644 mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_wrw.mlir create mode 100644 mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_gemmN.mlir diff --git a/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_fwd.mlir b/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_fwd.mlir new file mode 100644 index 000000000000..d2c080a168b5 --- /dev/null +++ b/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_fwd.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-miopen-driver --operation=conv2d -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG1 + +// CHECK_RESNET50_CONFIG1: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG1: [1] + +// RUN: mlir-miopen-driver --operation=conv2d -t f16 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG2 + +// CHECK_RESNET50_CONFIG2: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG2: [1] + +// RUN: mlir-miopen-driver --operation=conv2d -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG3 + +// CHECK_RESNET50_CONFIG3: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG3: [1] + +// RUN: mlir-miopen-driver --operation=conv2d -t f16 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG4 + +// CHECK_RESNET50_CONFIG4: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG4: [1] + + diff --git a/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_wrw.mlir b/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_wrw.mlir new file mode 100644 index 000000000000..76fcca32888d --- /dev/null +++ b/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_all_wrw.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-miopen-driver --operation=conv2d_bwd_weight -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG1 + +// CHECK_RESNET50_CONFIG1: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG1: [1] + +// RUN: mlir-miopen-driver --operation=conv2d_bwd_weight -t f16 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG2 + +// CHECK_RESNET50_CONFIG2: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG2: [1] + +// RUN: mlir-miopen-driver --operation=conv2d_bwd_weight -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG3 + +// CHECK_RESNET50_CONFIG3: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG3: [1] + +// RUN: mlir-miopen-driver --operation=conv2d_bwd_weight -t f16 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=20 -groupsize=1 -in_channels=3 -out_channels=6 -in_h=32 -in_w=32 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG4 + +// CHECK_RESNET50_CONFIG4: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG4: [1] + + diff --git a/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_gemmN.mlir b/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_gemmN.mlir new file mode 100644 index 000000000000..98eaafde1173 --- /dev/null +++ b/mlir/test/mlir-miopen-driver/auto_e2e/padding_kernel_gemmN.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-miopen-driver --operationconv2d_bwd_weight -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=256 -groupsize=1 -in_channels=3 -out_channels=64 -in_h=224 -in_w=224 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG1 + +// CHECK_RESNET50_CONFIG1: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG1: [1] + +// RUN: mlir-miopen-driver --operationconv2d_bwd_weight -t f16 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=256 -groupsize=1 -in_channels=3 -out_channels=64 -in_h=224 -in_w=224 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG2 + +// CHECK_RESNET50_CONFIG2: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG2: [1] + +// RUN: mlir-miopen-driver --operationconv2d_bwd_weight -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=256 -groupsize=1 -in_channels=3 -out_channels=64 -in_h=224 -in_w=224 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG3 + +// CHECK_RESNET50_CONFIG3: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG3: [1] + +// RUN: mlir-miopen-driver --operationconv2d_bwd_weight -t f16 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=256 -groupsize=1 -in_channels=3 -out_channels=64 -in_h=224 -in_w=224 -fil_h=7 -fil_w=7 --dilation_h=1 --dilation_w=1 --padding_h=3 --padding_w=3 --conv_stride_h=2 --conv_stride_w=2 -pv %random_data %xdlops -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET50_CONFIG4 + +// CHECK_RESNET50_CONFIG4: Unranked Memref base@ = 0x{{.*}} rank = 1 offset = 0 sizes = [1] strides = [1] data = +// CHECK_RESNET50_CONFIG4: [1] + + From e9c5048830746d14ae914703be3700a152390f56 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Thu, 10 Jun 2021 09:37:25 +0000 Subject: [PATCH 04/11] Revert "[HOTFIX] Temporarily disable index diff map when gemmKExtra > 0. (#261)" This reverts commit 3fd1610156b008d8f41722f40116262397ec5b00. --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 48 ++++--------------- 1 file changed, 8 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index 3e9ce561e36e..015beef67d75 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -7169,12 +7169,6 @@ struct ThreadwiseCopyRewritePattern // false : use the faster index diff map. auto legacyLoadAttr = op->getAttr("legacy_load"); auto legacyStoreAttr = op->getAttr("legacy_store"); - bool legacyLoad = - (legacyLoadAttr && - legacyLoadAttr.template cast().getValue() == true); - bool legacyStore = - (legacyStoreAttr && - legacyStoreAttr.template cast().getValue() == true); Optional composedSourceTransform; Optional composedDestTransform; @@ -7205,36 +7199,6 @@ struct ThreadwiseCopyRewritePattern return failure(); } - // FIXME. XXX. - // Workaround to obtain gemmKExtra attribute. - // And use it to override legacy load/store debug switch. - auto overrideLoadStoreHack = - [](const DictionaryAttr &transformSpec) -> bool { - if (transformSpec) { - Attribute metadataAttr = transformSpec.get("metadata"); - if (metadataAttr) { - ArrayAttr layeredTransformMetadata = - metadataAttr.template cast(); - for (unsigned iter = 0; iter < layeredTransformMetadata.size(); - ++iter) { - DictionaryAttr dictAttr = - layeredTransformMetadata[iter].template cast(); - auto gemmKExtraAttr = dictAttr.get("gemmKExtra"); - if (gemmKExtraAttr) { - auto gemmKExtra = - gemmKExtraAttr.template cast().getInt(); - if (gemmKExtra > 0) { - return true; - } - } - } - } - } - return false; - }; - legacyLoad = overrideLoadStoreHack(srcTransformSpec); - legacyStore = overrideLoadStoreHack(destTransformSpec); - // Populate the vector to hold source and dest coordinate. SmallVector sourceCoord; SmallVector destCoord; @@ -7343,7 +7307,8 @@ struct ThreadwiseCopyRewritePattern // wthe the metadata. // Only do such computation in the new approach where index diff maps // would be used. - if (legacyLoad == false) { + if (!legacyLoadAttr || + (legacyLoadAttr.template cast().getValue() == false)) { // Populate coorindates across the layers of transformations. if (srcTransformSpec) { Attribute metadataAttr = srcTransformSpec.get("metadata"); @@ -7374,7 +7339,8 @@ struct ThreadwiseCopyRewritePattern // wthe the metadata. // Only do such computation in the new approach where index diff maps // would be used. - if (legacyStore == false) { + if (!legacyStoreAttr || + (legacyStoreAttr.template cast().getValue() == false)) { // Populate coorindates across the layers of transformations. if (destTransformSpec) { Attribute metadataAttr = destTransformSpec.get("metadata"); @@ -7415,7 +7381,8 @@ struct ThreadwiseCopyRewritePattern bool toExit = false; do { // Use the old logic in case "legacy_load" attribute is specified. - if (legacyLoad == true) { + if (legacyLoadAttr && + (legacyLoadAttr.template cast().getValue() == true)) { computeTopAndBottomIndicesWithAffineMap( b, loc, srcUpperIndices, srcLowerIndices, sourceCoord, loopIVsPerAccessOrder, dimAccessOrder, layeredSourceTransform); @@ -7438,7 +7405,8 @@ struct ThreadwiseCopyRewritePattern b, loc, scalarValue, sourceElementType, destElementType); // Use the old logic in case "legacy_store" attribute is specified. - if (legacyStore == true) { + if (legacyStoreAttr && + (legacyStoreAttr.template cast().getValue() == true)) { computeTopAndBottomIndicesWithAffineMap( b, loc, destUpperIndices, destLowerIndices, destCoord, loopIVsPerAccessOrder, dimAccessOrder, layeredDestTransform); From 116cf1da84e6c33a46d5ed118100c8f3e8f51a72 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Thu, 10 Jun 2021 09:40:05 +0000 Subject: [PATCH 05/11] remove affinemap of rightpad --- .../MIOpen/Transforms/AffineTransforms.cpp | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp b/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp index 1fb7f1ce6677..4aa1a160b211 100644 --- a/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp +++ b/mlir/lib/Dialect/MIOpen/Transforms/AffineTransforms.cpp @@ -55,9 +55,6 @@ AffineMap AffineTransforms::buildIndexAffineMap(miopen::TransformOp op) { } else if (transformAttr.getValue() == "Pad") { assert(srcDimAttr.size() == destDimAttr.size()); - // FIXME: wrw padding kernel, rank of memref is not the same with - // padding transform, so close it now until issue fixing - bool closeRightPadding = true; auto parameters = dimLayoutAttr.get("parameters").cast(); for (unsigned j = 0; j < srcDimAttr.size(); ++j) { // example of h and w pad parameters [0, 2, 3, 1] : @@ -68,51 +65,12 @@ AffineMap AffineTransforms::buildIndexAffineMap(miopen::TransformOp op) { // leftPad = 1 rightPad = 2 auto leftPad = parameters.getValue()[j * 2].cast().getInt(); - auto rightPad = - parameters.getValue()[j * 2 + 1].cast().getInt(); auto srcDim = srcDimAttr.getValue()[j].cast().getInt(); auto destDim = destDimAttr.getValue()[j].cast().getInt(); auto expr = getAffineDimExpr(destDim, op.getContext()) + getAffineConstantExpr(-leftPad, op.getContext()); - if (leftPad == 0 && rightPad != 0 & !closeRightPadding) { - // when leftPad == 0 , your original expr is just minus leftpad, but - // leftpad is zero, affinemap do not have minus out of boundary - // check depends on minus symbol , it will not do out of boundary - // check even rightpad part is oob example of leftPad == 0 && - // rightPad != 0: - // - // srcIndex0 srcIndex1 ... srcIndex[src_size - 1] - // dstIndex0 dstIndex1 ... dstIndex[src_size - 1] dstIndex[rightpad] - // index0 index1 index[src_size -1] index[src_size] - // can't find index[src_size] in src - // so we need to force it to do out of boundary check , - // - // the idea : - // dst index : - // dstIndex0 dstIndex1 ... dstIndex[src_size -1] dstIndex[rightpad] - // src index computed: - // srcIndex0 srcIndex1 ... srcIndex[src_size - 1] src_size+1 - // - // how to achieve it: - // dstIndex + (dstIndex/srcsize) + 1 - 1 - // - // the expr is : - // dstIndex + ceildiv(dstIndex+1/srcsize) - 1 - // the same with above but the - // minus symbol exist after optimization - // - // but if we use the equation above, when srcsize = 1 - // affinemap will optimized and no minus symbol - // just add more 1 can generate minus symbol - // the final expr is : - // dstIndex + ceildiv((dstIndex+2)/(srcsize+1)) - 1 - expr = ((getAffineDimExpr(destDim, op.getContext()) + 2) - .ceilDiv(inputShape[srcDim] + 1)) + - getAffineDimExpr(destDim, op.getContext()) - - getAffineConstantExpr(1, op.getContext()); - } affExprsMap.insert({srcDim, expr}); } } else if (transformAttr.getValue() == "Merge" || From f49cc9a5a2e4d05ebfbf9b40a733444a2f365dd4 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Thu, 10 Jun 2021 10:18:06 +0000 Subject: [PATCH 06/11] fix comments --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 21 ++++--- .../Dialect/MIOpen/lowering_top_level.mlir | 58 +++++++++++++++++-- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index 015beef67d75..920f4b722d89 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -1667,7 +1667,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim2Attr.push_back(b.getNamedAttr( "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); } - + // filter of forward, gemmM=k filterOobCheckDims.insert(nameToDims["k"]); } @@ -2366,7 +2366,9 @@ struct Conv2DRewritePattern : public OpRewritePattern { // we pad CRS to 32 or 64, then mlir can do gemm // we add more one transform to do pad - // input do not GemmM when forward and backward weights + // input forward : gemmK,gemmN + // backward weights: gemmK,gemmN + // so we don't need to pad gemmK bool inputCheckPadGemmK = false; bool inputCheckPadGemmN = false; inputCheckPadGemmK = @@ -2446,6 +2448,9 @@ struct Conv2DRewritePattern : public OpRewritePattern { b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); + // input gemmK fwd: CYX backward weights:NHW + // due to it's load , we can use whole dim in gemmK + // if it's store , use top one if (convOpType == miopen::ConvOpType::Conv2DOpType) { inputOobCheckDims.insert(nameToDims["ci"]); } else if (convOpType == miopen::ConvOpType::Conv2DBwdWeightOpType) { @@ -2464,7 +2469,6 @@ struct Conv2DRewritePattern : public OpRewritePattern { gemmDim2TargetName = b.getStringAttr(gemmNPad_name); paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; - paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back(b.getNamedAttr( @@ -2474,6 +2478,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); + // input fwd gemmN: nhw + // backward weights :CYX if (convOpType == miopen::ConvOpType::Conv2DOpType) { inputOobCheckDims.insert(nameToDims["ni"]); inputOobCheckDims.insert(nameToDims["hi"]); @@ -2806,7 +2812,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim1Attr.push_back( b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); - // wrw + // output wrw gemmK is nhw outputOobCheckDims.insert(nameToDims["no"]); outputOobCheckDims.insert(nameToDims["ho"]); outputOobCheckDims.insert(nameToDims["wo"]); @@ -2818,7 +2824,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { isOutputPad = true; isGemmDim1Pad = true; gemmDim1TargetName = b.getStringAttr(gemmMPad_name); - // fwd, k + // output forward gemmM is k paddingOutputShape[1] = paddingOutputShape[1] + gemmMExtra; sourceGemmDim1Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); @@ -2833,7 +2839,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { isOutputPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmMPad_name); - // wrw ,k + // output backward weights gemmM is k paddingOutputShape[2] = paddingOutputShape[2] + gemmMExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); @@ -2849,7 +2855,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { if (outputCheckPadGemmN) { if (arg2TargetLayoutName2 == "gemmN") { - // fwd dim 2,nhw + // fwd output gemmN is nhw isOutputPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); @@ -2864,6 +2870,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { "names", b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); // FIXME: to set dim in merge transormation to oob store, // set only top dim or you will get zero values + // outputOobCheckDims.insert(nameToDims["no"]); } } diff --git a/mlir/test/Dialect/MIOpen/lowering_top_level.mlir b/mlir/test/Dialect/MIOpen/lowering_top_level.mlir index da6870ca619e..0c64a378684b 100644 --- a/mlir/test/Dialect/MIOpen/lowering_top_level.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_top_level.mlir @@ -107,7 +107,7 @@ func @miopen_conv2d_bwd_data_f16(%filter : memref<1x128x8x3x3xf16>, %input : mem // CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmN"\].*}} // CHECK-NEXT: miopen.gridwise_gemm -func @miopen_conv2d_bwd_weight(%filter : memref<1x20x8x3x3xf32>, %input : memref<7x1x8x32x32xf32>, %output : memref<7x1x20x30x30xf32>) { +func @miopen_conv2d_bwd_weight(%filter : memref<1x128x8x3x3xf32>, %input : memref<128x1x8x32x32xf32>, %output : memref<128x1x128x30x30xf32>) { miopen.conv2d_bwd_weight(%filter, %input, %output) { arch = "gfx906", num_cu = 64, @@ -117,12 +117,60 @@ func @miopen_conv2d_bwd_weight(%filter : memref<1x20x8x3x3xf32>, %input : memref dilations = [1, 1], strides = [1, 1], padding = [0, 0, 0 ,0] - } : memref<1x20x8x3x3xf32>, memref<7x1x8x32x32xf32>, memref<7x1x20x30x30xf32> + } : memref<1x128x8x3x3xf32>, memref<128x1x8x32x32xf32>, memref<128x1x128x30x30xf32> return } // CHECK-LABEL: func {{@miopen_conv2d_bwd_weight.*%arg0.*%arg1.*%arg2}} // CHECK-NOT: miopen.conv2d_bwd_weight // CHECK-NEXT: {{miopen.transform\(%arg0\).* upper_layer_layout = \["gemmG", "gemmM", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmM", "gemmNPad"\].*}} +// CHECK-NEXT: miopen.transform(%arg1) +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmNPad"\].*}} +// CHECK-NEXT: {{miopen.transform\(%arg2\).* upper_layer_layout = \["gemmG", "gemmK", "gemmM"\].*}} +// CHECK-NEXT: miopen.gridwise_gemm + +func @miopen_conv2d_bwd_weight_f16(%filter : memref<1x128x8x3x3xf16>, %input : memref<128x1x8x32x32xf16>, %output : memref<128x1x128x30x30xf16>) { + miopen.conv2d_bwd_weight(%filter, %input, %output) { + arch = "gfx906", + num_cu = 64, + filter_layout = ["g", "k", "c", "y", "x"], + input_layout = ["ni", "gi", "ci", "hi", "wi"], + output_layout = ["no", "go", "ko", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0, 0 ,0] + } : memref<1x128x8x3x3xf16>, memref<128x1x8x32x32xf16>, memref<128x1x128x30x30xf16> + return +} +// CHECK-LABEL: func {{@miopen_conv2d_bwd_weight_f16.*%arg0.*%arg1.*%arg2}} +// CHECK-NOT: miopen.conv2d_bwd_weight +// CHECK-NEXT: {{miopen.transform\(%arg0\).* upper_layer_layout = \["gemmG", "gemmM", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmM", "gemmNPad"\].*}} +// CHECK-NEXT: miopen.transform(%arg1) +// CHECK-NEXT: miopen.transform +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmN"\].*}} +// CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmK", "gemmNPad"\].*}} +// CHECK-NEXT: {{miopen.transform\(%arg2\).* upper_layer_layout = \["gemmG", "gemmK", "gemmM"\].*}} +// CHECK-NEXT: miopen.gridwise_gemm + +func @miopen_conv2d_bwd_weight_padALL(%filter : memref<1x20x8x3x3xf32>, %input : memref<7x1x8x32x32xf32>, %output : memref<7x1x20x30x30xf32>) { + miopen.conv2d_bwd_weight(%filter, %input, %output) { + arch = "gfx906", + num_cu = 64, + filter_layout = ["g", "k", "c", "y", "x"], + input_layout = ["ni", "gi", "ci", "hi", "wi"], + output_layout = ["no", "go", "ko", "ho", "wo"], + dilations = [1, 1], + strides = [1, 1], + padding = [0, 0, 0 ,0] + } : memref<1x20x8x3x3xf32>, memref<7x1x8x32x32xf32>, memref<7x1x20x30x30xf32> + return +} +// CHECK-LABEL: func {{@miopen_conv2d_bwd_weight_padALL.*%arg0.*%arg1.*%arg2}} +// CHECK-NOT: miopen.conv2d_bwd_weight +// CHECK-NEXT: {{miopen.transform\(%arg0\).* upper_layer_layout = \["gemmG", "gemmM", "gemmN"\].*}} // CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmMPad", "gemmNPad"\].*}} // CHECK-NEXT: miopen.transform(%arg1) // CHECK-NEXT: miopen.transform @@ -132,7 +180,7 @@ func @miopen_conv2d_bwd_weight(%filter : memref<1x20x8x3x3xf32>, %input : memref // CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmKPad", "gemmMPad"\].*}} // CHECK-NEXT: miopen.gridwise_gemm -func @miopen_conv2d_bwd_weight_f16(%filter : memref<1x20x8x3x3xf16>, %input : memref<7x1x8x32x32xf16>, %output : memref<7x1x20x30x30xf16>) { +func @miopen_conv2d_bwd_weight_padALL_f16(%filter : memref<1x20x8x3x3xf16>, %input : memref<7x1x8x32x32xf16>, %output : memref<7x1x20x30x30xf16>) { miopen.conv2d_bwd_weight(%filter, %input, %output) { arch = "gfx906", num_cu = 64, @@ -145,8 +193,8 @@ func @miopen_conv2d_bwd_weight_f16(%filter : memref<1x20x8x3x3xf16>, %input : me } : memref<1x20x8x3x3xf16>, memref<7x1x8x32x32xf16>, memref<7x1x20x30x30xf16> return } -// CHECK-LABEL: func {{@miopen_conv2d_bwd_weight.*%arg0.*%arg1.*%arg2}} -// CHECK-NOT: miopen.conv2d_bwd_weight_f16 +// CHECK-LABEL: func {{@miopen_conv2d_bwd_weight_padALL_f16.*%arg0.*%arg1.*%arg2}} +// CHECK-NOT: miopen.conv2d_bwd_weight // CHECK-NEXT: {{miopen.transform\(%arg0\).* upper_layer_layout = \["gemmG", "gemmM", "gemmN"\].*}} // CHECK-NEXT: {{miopen.transform.* upper_layer_layout = \["gemmG", "gemmMPad", "gemmNPad"\].*}} // CHECK-NEXT: miopen.transform(%arg1) From 49761c443ce2c67b72682d790138bd35b2897f57 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Fri, 11 Jun 2021 05:23:54 +0000 Subject: [PATCH 07/11] Revert "Revert "[HOTFIX] Temporarily disable index diff map when gemmKExtra > 0. (#261)"" This reverts commit 76f9b25ac7fce88be100441823a793717359184b. --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 48 +++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index 920f4b722d89..61c0c342308d 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -7176,6 +7176,12 @@ struct ThreadwiseCopyRewritePattern // false : use the faster index diff map. auto legacyLoadAttr = op->getAttr("legacy_load"); auto legacyStoreAttr = op->getAttr("legacy_store"); + bool legacyLoad = + (legacyLoadAttr && + legacyLoadAttr.template cast().getValue() == true); + bool legacyStore = + (legacyStoreAttr && + legacyStoreAttr.template cast().getValue() == true); Optional composedSourceTransform; Optional composedDestTransform; @@ -7206,6 +7212,36 @@ struct ThreadwiseCopyRewritePattern return failure(); } + // FIXME. XXX. + // Workaround to obtain gemmKExtra attribute. + // And use it to override legacy load/store debug switch. + auto overrideLoadStoreHack = + [](const DictionaryAttr &transformSpec) -> bool { + if (transformSpec) { + Attribute metadataAttr = transformSpec.get("metadata"); + if (metadataAttr) { + ArrayAttr layeredTransformMetadata = + metadataAttr.template cast(); + for (unsigned iter = 0; iter < layeredTransformMetadata.size(); + ++iter) { + DictionaryAttr dictAttr = + layeredTransformMetadata[iter].template cast(); + auto gemmKExtraAttr = dictAttr.get("gemmKExtra"); + if (gemmKExtraAttr) { + auto gemmKExtra = + gemmKExtraAttr.template cast().getInt(); + if (gemmKExtra > 0) { + return true; + } + } + } + } + } + return false; + }; + legacyLoad = overrideLoadStoreHack(srcTransformSpec); + legacyStore = overrideLoadStoreHack(destTransformSpec); + // Populate the vector to hold source and dest coordinate. SmallVector sourceCoord; SmallVector destCoord; @@ -7314,8 +7350,7 @@ struct ThreadwiseCopyRewritePattern // wthe the metadata. // Only do such computation in the new approach where index diff maps // would be used. - if (!legacyLoadAttr || - (legacyLoadAttr.template cast().getValue() == false)) { + if (legacyLoad == false) { // Populate coorindates across the layers of transformations. if (srcTransformSpec) { Attribute metadataAttr = srcTransformSpec.get("metadata"); @@ -7346,8 +7381,7 @@ struct ThreadwiseCopyRewritePattern // wthe the metadata. // Only do such computation in the new approach where index diff maps // would be used. - if (!legacyStoreAttr || - (legacyStoreAttr.template cast().getValue() == false)) { + if (legacyStore == false) { // Populate coorindates across the layers of transformations. if (destTransformSpec) { Attribute metadataAttr = destTransformSpec.get("metadata"); @@ -7388,8 +7422,7 @@ struct ThreadwiseCopyRewritePattern bool toExit = false; do { // Use the old logic in case "legacy_load" attribute is specified. - if (legacyLoadAttr && - (legacyLoadAttr.template cast().getValue() == true)) { + if (legacyLoad == true) { computeTopAndBottomIndicesWithAffineMap( b, loc, srcUpperIndices, srcLowerIndices, sourceCoord, loopIVsPerAccessOrder, dimAccessOrder, layeredSourceTransform); @@ -7412,8 +7445,7 @@ struct ThreadwiseCopyRewritePattern b, loc, scalarValue, sourceElementType, destElementType); // Use the old logic in case "legacy_store" attribute is specified. - if (legacyStoreAttr && - (legacyStoreAttr.template cast().getValue() == true)) { + if (legacyStore == true) { computeTopAndBottomIndicesWithAffineMap( b, loc, destUpperIndices, destLowerIndices, destCoord, loopIVsPerAccessOrder, dimAccessOrder, layeredDestTransform); From 7f7f730b42323924cf377086aa888b30825f54cc Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Fri, 11 Jun 2021 08:29:57 +0000 Subject: [PATCH 08/11] enbale all padding --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 44 ++++++++++++++++--- .../MIOpen/lowering_padding_kernel.mlir | 12 ++--- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index 61c0c342308d..375d0b9f3b1a 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -1509,11 +1509,13 @@ struct Conv2DRewritePattern : public OpRewritePattern { "gridwise_gemm_argument_position", b.getI32IntegerAttr(fields.gridwiseGemmArgumentPosition[0]))); - // set gemmMExtra & gemmKExtra + // set gemmMExtra & gemmKExtra & gemmNExtra transformedFilterAttrs.push_back( b.getNamedAttr("gemmMExtra", b.getI32IntegerAttr(gemmMExtra))); transformedFilterAttrs.push_back( b.getNamedAttr("gemmKExtra", b.getI32IntegerAttr(gemmKExtra))); + transformedFilterAttrs.push_back( + b.getNamedAttr("gemmNExtra", b.getI32IntegerAttr(gemmNExtra))); // set needExtraPad transformedFilterAttrs.push_back( b.getNamedAttr("extraPad", b.getBoolAttr(needExtraPad))); @@ -1626,9 +1628,13 @@ struct Conv2DRewritePattern : public OpRewritePattern { b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); } // filter of forward, gemmK=c*y*x - filterOobCheckDims.insert(nameToDims["c"]); - filterOobCheckDims.insert(nameToDims["y"]); - filterOobCheckDims.insert(nameToDims["x"]); + if (filterYDim == 2) { + // kyxc + filterOobCheckDims.insert(nameToDims["y"]); + } else { + // kcyx + filterOobCheckDims.insert(nameToDims["c"]); + } } if (filterCheckPadGemmM) { @@ -1918,11 +1924,13 @@ struct Conv2DRewritePattern : public OpRewritePattern { reorderedPaddedInputDimNames.begin(), reorderedPaddedInputDimNames.end())))); - // set gemmKExtra & gemmNExtra + // set gemmKExtra & gemmNExtra & gemmNExtra paddedInputAttrs.push_back( b.getNamedAttr("gemmKExtra", b.getI32IntegerAttr(gemmKExtra))); paddedInputAttrs.push_back( b.getNamedAttr("gemmNExtra", b.getI32IntegerAttr(gemmNExtra))); + paddedInputAttrs.push_back( + b.getNamedAttr("gemmMExtra", b.getI32IntegerAttr(gemmMExtra))); // set needExtraPad paddedInputAttrs.push_back( b.getNamedAttr("extraPad", b.getBoolAttr(needExtraPad))); @@ -2896,6 +2904,14 @@ struct Conv2DRewritePattern : public OpRewritePattern { layoutAttr2.append(targetGemmDim2Attr.begin(), targetGemmDim2Attr.end()); layoutAttr2.append(sourceGemmDim2Attr.begin(), sourceGemmDim2Attr.end()); + // set gemmKExtra & gemmNExtra & gemmNExtra + paddingOutputAttrs.push_back( + b.getNamedAttr("gemmKExtra", b.getI32IntegerAttr(gemmKExtra))); + paddingOutputAttrs.push_back( + b.getNamedAttr("gemmNExtra", b.getI32IntegerAttr(gemmNExtra))); + paddingOutputAttrs.push_back( + b.getNamedAttr("gemmMExtra", b.getI32IntegerAttr(gemmMExtra))); + paddingOutputAttrs.push_back(b.getNamedAttr( "layout", b.getArrayAttr({ b.getDictionaryAttr({ArrayRef( @@ -7227,6 +7243,8 @@ struct ThreadwiseCopyRewritePattern DictionaryAttr dictAttr = layeredTransformMetadata[iter].template cast(); auto gemmKExtraAttr = dictAttr.get("gemmKExtra"); + auto gemmMExtraAttr = dictAttr.get("gemmMExtra"); + auto gemmNExtraAttr = dictAttr.get("gemmNExtra"); if (gemmKExtraAttr) { auto gemmKExtra = gemmKExtraAttr.template cast().getInt(); @@ -7234,6 +7252,22 @@ struct ThreadwiseCopyRewritePattern return true; } } + + if (gemmMExtraAttr) { + auto gemmMExtra = + gemmMExtraAttr.template cast().getInt(); + if (gemmMExtra > 0) { + return true; + } + } + + if (gemmNExtraAttr) { + auto gemmNExtra = + gemmNExtraAttr.template cast().getInt(); + if (gemmNExtra > 0) { + return true; + } + } } } } diff --git a/mlir/test/Dialect/MIOpen/lowering_padding_kernel.mlir b/mlir/test/Dialect/MIOpen/lowering_padding_kernel.mlir index 891af04d0081..e128ff462af7 100644 --- a/mlir/test/Dialect/MIOpen/lowering_padding_kernel.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_padding_kernel.mlir @@ -18,9 +18,9 @@ func @miopen_conv2d_kcyx_nchw_nkhw_padding_kernel(%filter : memref<32x128x2x3x3x return } // CHECK-LABEL: func @miopen_conv2d -// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmMExtra = 0 : i32,.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} @@ -41,8 +41,8 @@ func @miopen_conv2d_kcyx_nchw_nkhw_no_extra_padding(%filter : memref<1x128x64x3x return } // CHECK-LABEL: func @miopen_conv2d -// CHECK-NEXT: {{miopen.transform.*{.*extraPad = false, gemmKExtra = 0 : i32, gemmMExtra = 0 : i32,.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*extraPad = false, gemmKExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*extraPad = false, gemmKExtra = 0 : i32, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*extraPad = false, gemmKExtra = 0 : i32, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*extraPad = false, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} @@ -62,9 +62,9 @@ func @miopen_conv2d_kcyx_nchw_nkhw_partial_padding_kernel(%filter : memref<32x12 return } // CHECK-LABEL: func @miopen_conv2d -// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmMExtra = 0 : i32,.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} -// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} +// CHECK-NEXT: {{miopen.transform.*{.*extraPad = true, gemmKExtra = 14 : i32, gemmMExtra = 0 : i32, gemmNExtra = 0 : i32,.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} // CHECK-NEXT: {{miopen.transform.*{.*}.*memref.*memref}} From 6aebd31a30eda0cb4dfeaff87dee38ecb535ab7f Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Fri, 11 Jun 2021 09:26:46 +0000 Subject: [PATCH 09/11] fix comments --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index 375d0b9f3b1a..db9b10d6866c 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -1542,7 +1542,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { int64_t nonGemmMSize = transformedFilterShape[1]; int64_t gemmMSize = transformedFilterShape[2]; // filter pad start - // filter : K & CRS , if CRS is under 64 or 32 + // K:output channel, C:input channel,R:filter high,S:filter width + // filter dim : K & merge(C,R,S) , if C*R*S is under 64 or 32 // we pad CRS to 32 or 64, then mlir can do gemm // we add more one transform to do pad bool filterCheckPadGemmM = false; @@ -1613,7 +1614,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { isGemmDim1Pad = true; gemmDim1TargetName = b.getStringAttr(gemmKPad_name); - // fwd + // forward paddingFilterShape[1] = nonGemmMSize + gemmKExtra; sourceGemmDim1Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); @@ -1639,7 +1640,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { if (filterCheckPadGemmM) { if (arg0TargetLayoutName1 == "gemmM") { - // wrw + // backward weights isFilterPad = true; isGemmDim1Pad = true; gemmDim1TargetName = b.getStringAttr(gemmMPad_name); @@ -1657,11 +1658,12 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim1Attr.push_back(b.getNamedAttr( "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); } else if (arg0TargetLayoutName2 == "gemmM") { - // fwd + // forward isFilterPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmMPad_name); paddingFilterShape[2] = gemmMSize + gemmMExtra; + // gemmM = k when forward, padd gemmMExtra sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back( @@ -1679,11 +1681,13 @@ struct Conv2DRewritePattern : public OpRewritePattern { if (filterCheckPadGemmN) { if (arg0TargetLayoutName2 == "gemmN") { - // wrw + // backward weights isFilterPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); paddingFilterShape[1] = nonGemmMSize + gemmNExtra; + // backward weights input: gemmK, gemmN + // so padd gemmNExtra sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back( @@ -2443,7 +2447,9 @@ struct Conv2DRewritePattern : public OpRewritePattern { if (arg1TargetLayoutName1 == "gemmK") { isInputPad = true; isGemmDim1Pad = true; - // fwd is cyx ,wrw is nhw + // both forward and backward weights dim1 of input matrix + // are gemmK ,but forward gemmK is combining c,y,x + // backward weights gemmK is combining n,h,w gemmDim1TargetName = b.getStringAttr(gemmKPad_name); paddingInputShape[1] = paddingInputShape[1] + gemmKExtra; @@ -2476,7 +2482,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; - + // both forward and backward weights have the same dim2 gemmN + // so padding gemmNExtra sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back(b.getNamedAttr( @@ -2486,8 +2493,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); - // input fwd gemmN: nhw - // backward weights :CYX + // input forward gemmN: n,h,w + // backward weights gemmN :C,Y,X if (convOpType == miopen::ConvOpType::Conv2DOpType) { inputOobCheckDims.insert(nameToDims["ni"]); inputOobCheckDims.insert(nameToDims["hi"]); @@ -2739,7 +2746,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { // output padding start // output matrix dim: K & NHW // when backward weight , GEMMK = NHW - // If Nhw is under 32 or 64 ,we pad it to 32 or 64 + // N:batch size, H:output height ,W:output width + // If size of N*h*w is under 32 or 64 ,we pad it to 32 or 64 // then mlir can do gemm // we just add more one transform to do it @@ -2809,7 +2817,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { isOutputPad = true; isGemmDim1Pad = true; gemmDim1TargetName = b.getStringAttr(gemmKPad_name); - // wrw dim 1 is nhw + // backward weights dim 1 is composing of (N,H,W) + // N:batch size, H: output height ,W:output width paddingOutputShape[1] = paddingOutputShape[1] + gemmKExtra; sourceGemmDim1Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); @@ -2834,6 +2843,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { gemmDim1TargetName = b.getStringAttr(gemmMPad_name); // output forward gemmM is k paddingOutputShape[1] = paddingOutputShape[1] + gemmMExtra; + // output forward gemmM is k + // so padding gemmMExtra sourceGemmDim1Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim1Attr.push_back(b.getNamedAttr( @@ -2848,6 +2859,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmMPad_name); // output backward weights gemmM is k + // so padding gemmMExtra paddingOutputShape[2] = paddingOutputShape[2] + gemmMExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); @@ -2863,11 +2875,13 @@ struct Conv2DRewritePattern : public OpRewritePattern { if (outputCheckPadGemmN) { if (arg2TargetLayoutName2 == "gemmN") { - // fwd output gemmN is nhw + // forward output gemmN is nhw isOutputPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); paddingOutputShape[2] = paddingOutputShape[2] + gemmNExtra; + // forward output gemmN is combining(N,H,W) + // so padding gemmNExtra sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back(b.getNamedAttr( From a7eac5d233003705aee135c9332e1944f620f011 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Sun, 13 Jun 2021 08:46:20 +0000 Subject: [PATCH 10/11] fix comments --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index db9b10d6866c..f77318a885ff 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -1542,7 +1542,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { int64_t nonGemmMSize = transformedFilterShape[1]; int64_t gemmMSize = transformedFilterShape[2]; // filter pad start - // K:output channel, C:input channel,R:filter high,S:filter width + // K:output channel, C:input channel,R:filter height,S:filter width // filter dim : K & merge(C,R,S) , if C*R*S is under 64 or 32 // we pad CRS to 32 or 64, then mlir can do gemm // we add more one transform to do pad @@ -1662,8 +1662,9 @@ struct Conv2DRewritePattern : public OpRewritePattern { isFilterPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmMPad_name); - paddingFilterShape[2] = gemmMSize + gemmMExtra; // gemmM = k when forward, padd gemmMExtra + paddingFilterShape[2] = gemmMSize + gemmMExtra; + // gemmM = k when forward sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back( @@ -1685,9 +1686,10 @@ struct Conv2DRewritePattern : public OpRewritePattern { isFilterPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); - paddingFilterShape[1] = nonGemmMSize + gemmNExtra; // backward weights input: gemmK, gemmN // so padd gemmNExtra + + paddingFilterShape[1] = nonGemmMSize + gemmNExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back( @@ -2481,9 +2483,10 @@ struct Conv2DRewritePattern : public OpRewritePattern { isInputPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); - paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; // both forward and backward weights have the same dim2 gemmN // so padding gemmNExtra + + paddingInputShape[2] = paddingInputShape[2] + gemmNExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back(b.getNamedAttr( @@ -2829,7 +2832,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { targetGemmDim1Attr.push_back( b.getNamedAttr("upper_layer_names", b.getArrayAttr({b.getStringAttr(gemmKPad_name)}))); - // output wrw gemmK is nhw + // output backward weights gemmK is composed of n,h,w, check all dims + // due to it's load , not store ,if it's store ,check only no dim outputOobCheckDims.insert(nameToDims["no"]); outputOobCheckDims.insert(nameToDims["ho"]); outputOobCheckDims.insert(nameToDims["wo"]); @@ -2850,7 +2854,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { sourceGemmDim1Attr.push_back(b.getNamedAttr( "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), b.getI32IntegerAttr(gemmMExtra)}))); - + // output forward gemmM is k, check the ko dim targetGemmDim1Attr.push_back(b.getNamedAttr( "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); outputOobCheckDims.insert(nameToDims["ko"]); @@ -2866,7 +2870,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { sourceGemmDim2Attr.push_back(b.getNamedAttr( "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), b.getI32IntegerAttr(gemmMExtra)}))); - + // output backward weights gemmM is k, xcv targetGemmDim2Attr.push_back(b.getNamedAttr( "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); outputOobCheckDims.insert(nameToDims["ko"]); @@ -2879,9 +2883,10 @@ struct Conv2DRewritePattern : public OpRewritePattern { isOutputPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmNPad_name); - paddingOutputShape[2] = paddingOutputShape[2] + gemmNExtra; // forward output gemmN is combining(N,H,W) // so padding gemmNExtra + + paddingOutputShape[2] = paddingOutputShape[2] + gemmNExtra; sourceGemmDim2Attr.push_back( b.getNamedAttr("transformation", b.getStringAttr("Pad"))); sourceGemmDim2Attr.push_back(b.getNamedAttr( @@ -2892,7 +2897,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { "names", b.getArrayAttr({b.getStringAttr(gemmNPad_name)}))); // FIXME: to set dim in merge transormation to oob store, // set only top dim or you will get zero values - // + // output forward gemmM is composed of n , h ,w, check the top dim :no outputOobCheckDims.insert(nameToDims["no"]); } } @@ -7256,6 +7261,9 @@ struct ThreadwiseCopyRewritePattern ++iter) { DictionaryAttr dictAttr = layeredTransformMetadata[iter].template cast(); + // enable workaround when padding kernel, + // if gemmKExtra || gemmMExtra || gemmNExtraAttr + // use workaround to skip index map errors auto gemmKExtraAttr = dictAttr.get("gemmKExtra"); auto gemmMExtraAttr = dictAttr.get("gemmMExtra"); auto gemmNExtraAttr = dictAttr.get("gemmNExtra"); From 2c895e0eed23752c1e6613e1ecd4c840575cc259 Mon Sep 17 00:00:00 2001 From: "kevin.chang" Date: Mon, 14 Jun 2021 15:02:40 +0000 Subject: [PATCH 11/11] fix comments --- mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index f77318a885ff..7e1792b90880 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -1542,9 +1542,9 @@ struct Conv2DRewritePattern : public OpRewritePattern { int64_t nonGemmMSize = transformedFilterShape[1]; int64_t gemmMSize = transformedFilterShape[2]; // filter pad start - // K:output channel, C:input channel,R:filter height,S:filter width - // filter dim : K & merge(C,R,S) , if C*R*S is under 64 or 32 - // we pad CRS to 32 or 64, then mlir can do gemm + // K:output channel, C:input channel,Y:filter height,X:filter width + // filter dim : K & merge(C,Y,X) , if C*Y*X is under 64 or 32 + // we pad CYX to 32 or 64, then mlir can do gemm // we add more one transform to do pad bool filterCheckPadGemmM = false; bool filterCheckPadGemmK = false; @@ -1662,7 +1662,7 @@ struct Conv2DRewritePattern : public OpRewritePattern { isFilterPad = true; isGemmDim2Pad = true; gemmDim2TargetName = b.getStringAttr(gemmMPad_name); - // gemmM = k when forward, padd gemmMExtra + // gemmM = k when forward, pad gemmMExtra paddingFilterShape[2] = gemmMSize + gemmMExtra; // gemmM = k when forward sourceGemmDim2Attr.push_back( @@ -2870,7 +2870,8 @@ struct Conv2DRewritePattern : public OpRewritePattern { sourceGemmDim2Attr.push_back(b.getNamedAttr( "parameters", b.getArrayAttr({b.getI32IntegerAttr(0), b.getI32IntegerAttr(gemmMExtra)}))); - // output backward weights gemmM is k, xcv + // output backward weights gemmM is k, + // so padding gemmMExtra targetGemmDim2Attr.push_back(b.getNamedAttr( "names", b.getArrayAttr({b.getStringAttr(gemmMPad_name)}))); outputOobCheckDims.insert(nameToDims["ko"]);