diff --git a/mlir/include/mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h b/mlir/include/mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h
index 1d4ce8c0ee08..6b520558b820 100644
--- a/mlir/include/mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h
+++ b/mlir/include/mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h
@@ -42,6 +42,9 @@ class Conv2dGenerator {
     SmallVector filterDimension;
     SmallVector inputDimension;
     SmallVector outputDimension;
+
+    int filterHeight;
+    int filterWidth;
   };
 
   Conv2dGenerator(const std::string &arch = "", int num_cu = 0,
@@ -98,6 +101,7 @@ class Conv2dGenerator {
     });
     return permutation;
   }
+  int getBwdDataKernelCount() const;
 
   // Generator config
   Config config;
diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
index 7abf00701001..f02c1e817824 100644
--- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
+++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
@@ -2407,7 +2407,7 @@ struct Conv2DRewritePattern : public OpRewritePattern {
   LogicalResult backwardData(T op, PatternRewriter &b) const {
     auto loc = op.getLoc();
-
+    auto gemmIdAttr = op->template getAttrOfType("gemm_id");
     auto archAttr = op->template getAttrOfType("arch");
     auto numCuAttr = op->template getAttrOfType("num_cu");
@@ -2500,8 +2500,8 @@ struct Conv2DRewritePattern : public OpRewritePattern {
     auto gcdStrideDilationH = math::gcd(strideH, dilationH);
     auto gcdStrideDilationW = math::gcd(strideW, dilationW);
 
-    auto yTilda = dilationH / gcdStrideDilationH;
-    auto xTilda = dilationW / gcdStrideDilationW;
+    auto yTilda = strideH / gcdStrideDilationH;
+    auto xTilda = strideW / gcdStrideDilationW;
 
     auto yDot = math::integer_divide_ceil(y, yTilda);
     auto xDot = math::integer_divide_ceil(x, xTilda);
@@ -2522,7 +2522,28 @@ struct Conv2DRewritePattern : public OpRewritePattern {
     auto hTildaSlice = iHTildaRight - iHTildaLeft;
     auto wTildaSlice = iWTildaRight - iWTildaLeft;
-    auto gemmId = 0;
+    auto getGemmId = [&](int kernelId) {
+      // kernelId 0 must map to gemmId 0
+      if (kernelId <= 0)
+        return 0;
+
+      llvm::SmallVector gemmIds;
+      for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) {
+        // gemm_k size is different for each GEMM
+        const auto iYTilda = gemmId / xTilda;
+        const auto iXTilda = gemmId % xTilda;
+
+        auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
+        auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
+        // gemmK must be > 0; otherwise there is no need to run this GEMM
+        if (yDotSlice * xDotSlice > 0) {
+          gemmIds.push_back(gemmId);
+        }
+      }
+      assert(gemmIds.size() > kernelId);
+      return gemmIds[kernelId];
+    };
+    auto gemmId = getGemmId(gemmIdAttr.getInt());
     auto iYTilda = gemmId / xTilda;
     auto iXTilda = gemmId % xTilda;
     auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
     auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
@@ -3766,6 +3787,7 @@ struct Conv2DRewritePattern : public OpRewritePattern {
     // Set attributes for gridwise_gemm op.
     llvm::SmallVector gridwiseGemmAttrs{
+        b.getNamedAttr("gemm_id", gemmIdAttr),
         b.getNamedAttr("arch", archAttr),
         b.getNamedAttr("num_cu", numCuAttr),
         b.getNamedAttr("filter_layout", filterLayoutAttr),
diff --git a/mlir/include/mlir/Dialect/MIOpen/Tuning/ConvContext.h b/mlir/include/mlir/Dialect/MIOpen/Tuning/ConvContext.h
index 8d92cf9531f5..12c816c832b4 100644
--- a/mlir/include/mlir/Dialect/MIOpen/Tuning/ConvContext.h
+++ b/mlir/include/mlir/Dialect/MIOpen/Tuning/ConvContext.h
@@ -28,15 +28,17 @@ struct ConvolutionContext : SQLiteSerializable {
   llvm::SmallVector strideVal;
   llvm::SmallVector dilationVal;
   llvm::SmallVector paddingVal;
+  int gemmId;
 
   ConvolutionContext(const llvm::SmallString<8> &architecture, int numCu,
                      miopen::ConvOpType op,
                      llvm::StringMap> dim,
                      llvm::SmallVector stride,
                      llvm::SmallVector dilation,
-                     llvm::SmallVector padding)
+                     llvm::SmallVector padding, int gemmid)
       : arch(architecture), num_cu(numCu), opType(op), dimIndexVal(dim),
-        strideVal(stride), dilationVal(dilation), paddingVal(padding) {}
+        strideVal(stride), dilationVal(dilation), paddingVal(padding),
+        gemmId(gemmid) {}
 
   llvm::StringMap> getDimIndexVal() const {
     return dimIndexVal;
@@ -146,6 +148,11 @@ template static ConvolutionContext populateConvContext(T &op) {
   auto archVal = op->template getAttrOfType("arch").getValue();
   int numCuVal = op->template getAttrOfType("num_cu").getInt();
+  auto gemmIdAttr = op->template getAttrOfType("gemm_id");
+  int gemmId = 0;
+  if (gemmIdAttr) {
+    gemmId = gemmIdAttr.getInt();
+  }
 
   llvm::StringMap> dimIndexVal;
@@ -176,8 +183,8 @@ template static ConvolutionContext populateConvContext(T &op) {
   llvm::SmallVector paddingVal;
   populateSeqVal(paddingAttr, paddingVal);
 
-  return {archVal, numCuVal, opType, dimIndexVal,
-          strideVal, dilationVal, paddingVal};
+  return {archVal, numCuVal, opType, dimIndexVal,
+          strideVal, dilationVal, paddingVal, gemmId};
 }
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MIOpen/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/MIOpen/Tuning/GridwiseGemmParams.h
index 1643dc3f827c..d9ed268e5575 100644
--- a/mlir/include/mlir/Dialect/MIOpen/Tuning/GridwiseGemmParams.h
+++ b/mlir/include/mlir/Dialect/MIOpen/Tuning/GridwiseGemmParams.h
@@ -111,14 +111,8 @@ class PopulateParamsBase {
         input1GemmKVectorizable = true;
       }
     } else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
-      // When K is the fastest changing dimension(3),
-      // gemmK dimension is vectorizable, gemmM is not, and vice versa.
-      // Vectorization width depending on length of K.
-      if (dimIndexVal["k"].first == 4) {
-        input1GemmKVectorizable = true;
-      } else {
-        input1GemmKVectorizable = false;
-      }
+      // always load gemmM first
+      input1GemmKVectorizable = false;
     } else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
       // When K is the fastest changing dimension,
       // gemmM dimension is vectorizable, gemmK is not, and vice versa.
@@ -200,6 +194,23 @@ class PopulateParamsBase {
     }
   }
 
+  static void obtainBwdDataFilterVecLen(ConvolutionContext &ctx,
+                                        int64_t &vecLen) {
+    auto dimIndexVal = ctx.dimIndexVal;
+    // Vectorization length logic is the same for forward and bwd_data
+    if (dimIndexVal["c"].first == 4) {
+      vecLen = dimIndexVal["c"].second;
+    } else if (dimIndexVal["c"].first == 2) {
+      // C is at position 2; the vectorization length depends on the last two dimensions
+      if (dimIndexVal["y"].second == 1 && dimIndexVal["x"].second == 1) {
+        vecLen = dimIndexVal["c"].second;
+      } else {
+        vecLen = 1;
+      }
+    } else {
+      vecLen = 1;
+    }
+  }
   static void obtainInputVecLen(ConvolutionContext &ctx, int64_t &vecLen) {
     auto dimIndexVal = ctx.dimIndexVal;
     if (dimIndexVal["ni"].first == 4) {
@@ -216,6 +227,26 @@ class PopulateParamsBase {
       vecLen = 1;
     }
   }
+  static void obtainBwdDataOutputVecLen(ConvolutionContext &ctx,
+                                        int64_t &vecLen) {
+    auto dimIndexVal = ctx.dimIndexVal;
+    if (dimIndexVal["ko"].first == 4) {
+      vecLen = dimIndexVal["ko"].second;
+    } else if (dimIndexVal["no"].first == 4) {
+      vecLen = dimIndexVal["no"].second;
+    } else if (dimIndexVal["no"].first == 0) {
+      if (dimIndexVal["ho"].first == 3 && dimIndexVal["wo"].first == 4) {
+        if (dimIndexVal["y"].second == 1 && dimIndexVal["x"].second == 1)
+          vecLen = dimIndexVal["ho"].second * dimIndexVal["wo"].second;
+        else
+          vecLen = 1;
+      } else
+        vecLen = 1;
+    } else {
+      vecLen = 1;
+    }
+  }
+
   static void obtainOutputVecLen(ConvolutionContext &ctx, int64_t &vecLen) {
     auto dimIndexVal = ctx.dimIndexVal;
     if (dimIndexVal["ko"].first == 4) {
@@ -250,7 +281,7 @@ class PopulateParamsBase {
     if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
       obtainFilterVecLen(ctx, vecLen);
     } else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
-      obtainFilterVecLen(ctx, vecLen);
+      obtainBwdDataFilterVecLen(ctx, vecLen);
     } else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
       obtainOutputVecLen(ctx, vecLen);
     }
@@ -261,7 +292,7 @@ class PopulateParamsBase {
     if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
       obtainInputVecLen(ctx, vecLen);
     } else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
-      obtainOutputVecLen(ctx, vecLen);
+      obtainBwdDataOutputVecLen(ctx, vecLen);
     } else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
       obtainInputVecLen(ctx, vecLen);
     }
@@ -385,8 +416,8 @@ class PopulateParamsBase {
     auto gcdStrideDilationH = math::gcd(strideH, dilationH);
     auto gcdStrideDilationW = math::gcd(strideW, dilationW);
 
-    auto yTilda = dilationH / gcdStrideDilationH;
-    auto xTilda = dilationW / gcdStrideDilationW;
+    auto yTilda = strideH / gcdStrideDilationH;
+    auto xTilda = strideW / gcdStrideDilationW;
 
     auto hTilda =
         ho + math::integer_divide_ceil(dilationH * (y - 1), strideH);
@@ -406,9 +437,9 @@ class PopulateParamsBase {
     auto hTildaSlice = iHTildaRight - iHTildaLeft;
     auto wTildaSlice = iWTildaRight - iWTildaLeft;
 
-    auto gemm_id = 0;
-    auto iYTilda = gemm_id / xTilda;
-    auto iXTilda = gemm_id % xTilda;
+    auto gemmId = ctx.gemmId;
+    auto iYTilda = gemmId / xTilda;
+    auto iXTilda = gemmId % xTilda;
     auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
     auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
 
diff --git a/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp b/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp
index b7ce744fd519..56817cae8a78 100644
--- a/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp
+++ b/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp
@@ -1,4 +1,5 @@
 #include "mlir/Dialect/MIOpen/Generator/Conv2dGenerator.h"
+#include "mlir/Dialect/MIOpen/utility/math.hpp"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
@@ -66,7 +67,7 @@ int Conv2dGenerator::getKernelCount() const {
   } else if (config.operation == "conv2d") {
     count = 1;
   } else if (config.operation == "conv2d_bwd_data") {
-    count = 1;
+    count = getBwdDataKernelCount();
  } else if (config.operation == "conv2d_bwd_weight") {
     count = 1;
   } else if (config.operation == "conv2d_dummy") {
@@ -75,6 +76,31 @@ int Conv2dGenerator::getKernelCount() const {
   return count;
 }
 
+int Conv2dGenerator::getBwdDataKernelCount() const {
+  auto gcdStrideDilationH =
+      math::gcd(config.strideHeight, config.dilationHeight);
+  auto gcdStrideDilationW = math::gcd(config.strideWidth, config.dilationWidth);
+
+  auto yTilda = config.strideHeight / gcdStrideDilationH;
+  auto xTilda = config.strideWidth / gcdStrideDilationW;
+
+  auto y = config.filterHeight;
+  auto x = config.filterWidth;
+  int count = 0;
+  for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) {
+    // gemm_k size is different for each GEMM
+    const auto iYTilda = gemmId / xTilda;
+    const auto iXTilda = gemmId % xTilda;
+
+    auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
+    auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
+    // gemmK must be > 0; otherwise there is no need to run this GEMM
+    if (yDotSlice * xDotSlice > 0)
+      count++;
+  }
+
+  return count;
+}
 Type Conv2dGenerator::getDataType(OpBuilder &builder) const {
   mlir::Type dataType;
   if (config.dataTypeStr == "f32" || config.dataTypeStr == "fp32") {
@@ -168,7 +194,8 @@ Conv2dGenerator::parseConvDims(int64_t batchSize, int64_t groupSize,
                                int64_t inputWidth, int64_t outputChannel,
                                int64_t outputHeight, int64_t outputWidth,
                                int64_t filterHeight, int64_t filterWidth) {
-
+  config.filterHeight = filterHeight;
+  config.filterWidth = filterWidth;
   static const std::string filterKeys = "kgcyx";
   int64_t filterVals[] = {outputChannel / groupSize, groupSize,
                           inputChannel / groupSize, filterHeight, filterWidth};
@@ -227,7 +254,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
                                              OpBuilder &builder,
                                              int kernel_id) {
   if (kernel_id == -1) {
-    kernel_id = config.kernelId;
+    kernel_id = std::max(config.kernelId, 0);
   }
 
   Type dataType = getDataType(builder);
@@ -284,6 +311,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
   }
 
   std::vector attributes{
+      builder.getNamedAttr("gemm_id", builder.getI32IntegerAttr(kernel_id)),
       builder.getNamedAttr("arch", builder.getStringAttr(config.arch)),
       builder.getNamedAttr("num_cu",
                            builder.getI32IntegerAttr(config.num_cu)),
@@ -323,43 +351,34 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
   attributes.push_back(
       builder.getNamedAttr("xdlopsV2", builder.getBoolAttr(true)));
 
-  if (kernel_id > 0) {
+  if (config.operation == "conv2d") {
+    auto convOp = builder.create(
+        builder.getUnknownLoc(), ArrayRef{},
+        ValueRange{func.getArgument(0), func.getArgument(1),
+                   func.getArgument(2)},
+        attributes);
+    block->push_front(convOp);
+  } else if (config.operation == "conv2d_bwd_data") {
+    auto convOp = builder.create(
+        builder.getUnknownLoc(), ArrayRef{},
+        ValueRange{func.getArgument(0), func.getArgument(1),
+                   func.getArgument(2)},
+        attributes);
+    block->push_front(convOp);
+  } else if (config.operation == "conv2d_bwd_weight") {
+    auto convOp = builder.create(
+        builder.getUnknownLoc(), ArrayRef{},
+        ValueRange{func.getArgument(0), func.getArgument(1),
+                   func.getArgument(2)},
+        attributes);
+ block->push_back(convOp); + } else if (config.operation == "conv2d_dummy") { auto convOp = builder.create( builder.getUnknownLoc(), ArrayRef{}, ValueRange{func.getArgument(0), func.getArgument(1), func.getArgument(2)}, attributes); block->push_front(convOp); - } else { - if (config.operation == "conv2d") { - auto convOp = builder.create( - builder.getUnknownLoc(), ArrayRef{}, - ValueRange{func.getArgument(0), func.getArgument(1), - func.getArgument(2)}, - attributes); - block->push_front(convOp); - } else if (config.operation == "conv2d_bwd_data") { - auto convOp = builder.create( - builder.getUnknownLoc(), ArrayRef{}, - ValueRange{func.getArgument(0), func.getArgument(1), - func.getArgument(2)}, - attributes); - block->push_front(convOp); - } else if (config.operation == "conv2d_bwd_weight") { - auto convOp = builder.create( - builder.getUnknownLoc(), ArrayRef{}, - ValueRange{func.getArgument(0), func.getArgument(1), - func.getArgument(2)}, - attributes); - block->push_back(convOp); - } else if (config.operation == "conv2d_dummy") { - auto convOp = builder.create( - builder.getUnknownLoc(), ArrayRef{}, - ValueRange{func.getArgument(0), func.getArgument(1), - func.getArgument(2)}, - attributes); - block->push_front(convOp); - } } auto returnOp = diff --git a/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir index 7e0868175ae4..ec227b46f073 100644 --- a/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_filter_tensor_ckyx_cnhw_knhw.mlir @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_ckyx_cnhw_knhw(%filter : memref<1x8x128x3x3xf32>, % output_layout = ["go", "ko", "no", "ho", "wo"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0, 0] + padding = [0, 0, 0, 0], + gemm_id = 0 } : memref<1x8x128x3x3xf32>, memref<1x8x128x32x32xf32>, memref<1x128x128x30x30xf32> return } diff --git a/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir b/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir index b01ad87a59d1..178c6d23f5a9 100644 --- a/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_gridwise_gemm_position_cyxk_chwn_khwn.mlir @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_cyxk_chwn_khwn(%filter : memref<1x8x3x3x128xf32>, % output_layout = ["go", "ko", "ho", "wo", "no"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0, 0] + padding = [0, 0, 0, 0], + gemm_id = 0 } : memref<1x8x3x3x128xf32>, memref<1x8x32x32x128xf32>, memref<1x128x30x30x128xf32> return } diff --git a/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir index 4ac6a763fd5c..6434a3fb51be 100644 --- a/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_input_tensor_cyxk_cnhw_knhw.mlir @@ -36,7 +36,8 @@ func @miopen_conv2d_bwd_data_cyxk_cnhw_knhw(%filter : memref<1x8x3x3x128xf32>, % output_layout = ["go", "ko", "no", "ho", "wo"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0, 0] + padding = [0, 0, 0, 0], + gemm_id = 0 } : memref<1x8x3x3x128xf32>, memref<1x8x128x32x32xf32>, memref<1x128x128x30x30xf32> return } diff --git a/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir b/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir index 9b1f10fb001a..3e75201b0547 100644 --- 
a/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_memref_kcyx_nchw_nkhw.mlir @@ -34,7 +34,8 @@ func @miopen_conv2d_bwd_data_kcyx_nchw_nkhw(%filter : memref<1x128x8x3x3xf32>, % output_layout = ["no", "go", "ko", "ho", "wo"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0, 0] + padding = [0, 0, 0, 0], + gemm_id = 0 } : memref<1x128x8x3x3xf32>, memref<128x1x8x32x32xf32>, memref<128x1x128x30x30xf32> return } diff --git a/mlir/test/Dialect/MIOpen/lowering_output_tensor_kyxc_nhwc_nhwk.mlir b/mlir/test/Dialect/MIOpen/lowering_output_tensor_kyxc_nhwc_nhwk.mlir index 112c5600b154..3fa500b78933 100644 --- a/mlir/test/Dialect/MIOpen/lowering_output_tensor_kyxc_nhwc_nhwk.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_output_tensor_kyxc_nhwc_nhwk.mlir @@ -35,7 +35,8 @@ func @miopen_conv2d_bwd_data_gkyxc_nhwgc_nhwgk(%filter : memref<1x128x3x3x8xf32> output_layout = ["no", "ho", "wo", "go", "ko"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0, 0] + padding = [0, 0, 0, 0], + gemm_id = 0 } : memref<1x128x3x3x8xf32>, memref<128x32x32x1x8xf32>, memref<128x30x30x1x128xf32> return } diff --git a/mlir/test/Dialect/MIOpen/lowering_top_level.mlir b/mlir/test/Dialect/MIOpen/lowering_top_level.mlir index 097d74a43b60..9ea979bcc037 100644 --- a/mlir/test/Dialect/MIOpen/lowering_top_level.mlir +++ b/mlir/test/Dialect/MIOpen/lowering_top_level.mlir @@ -60,7 +60,8 @@ func @miopen_conv2d_bwd_data(%filter : memref<1x128x8x3x3xf32>, %input : memref< output_layout = ["no", "go", "ko", "ho", "wo"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0 ,0] + padding = [0, 0, 0 ,0], + gemm_id = 0 } : memref<1x128x8x3x3xf32>, memref<128x1x8x32x32xf32>, memref<128x1x128x30x30xf32> return } @@ -87,7 +88,8 @@ func @miopen_conv2d_bwd_data_f16(%filter : memref<1x128x8x3x3xf16>, %input : mem output_layout = ["no", "go", "ko", "ho", "wo"], dilations = [1, 1], strides = [1, 1], - padding = [0, 0, 0 ,0] + padding = [0, 0, 0 ,0], + gemm_id = 0 } : memref<1x128x8x3x3xf16>, memref<128x1x8x32x32xf16>, memref<128x1x128x30x30xf16> return } diff --git a/mlir/test/Dialect/MIOpen/translate_cflags_bwd.mlir b/mlir/test/Dialect/MIOpen/translate_cflags_bwd.mlir index 44817eb56bfe..16dc9ac9e786 100644 --- a/mlir/test/Dialect/MIOpen/translate_cflags_bwd.mlir +++ b/mlir/test/Dialect/MIOpen/translate_cflags_bwd.mlir @@ -5,9 +5,9 @@ func @basic_parsing(%filter : memref, %input : memref, %ou arch = "gfx906", num_cu = 64, kernel_algorithm = "backward_data_v4r1", - filter_dimension = [1, 128, 8, 4, 4], + filter_dimension = [1, 128, 32, 4, 4], filter_layout = ["g", "k", "c", "y", "x"], - input_dimension = [128, 1, 8, 32, 32], + input_dimension = [128, 1, 32, 32, 32], input_layout = ["ni", "gi", "ci", "hi", "wi"], output_dimension = [128, 1, 128, 32, 32], output_layout = ["no", "go", "ko", "ho", "wo"], @@ -20,7 +20,7 @@ func @basic_parsing(%filter : memref, %input : memref, %ou return } // CHECK-LABEL: basic_parsing -// CHECK: -DCK_PARAM_PROBLEM_C=8 +// CHECK: -DCK_PARAM_PROBLEM_C=32 // CHECK: -DCK_PARAM_PROBLEM_CONV_DILATION_H=1 // CHECK: -DCK_PARAM_PROBLEM_CONV_DILATION_W=1 // CHECK: -DCK_PARAM_PROBLEM_CONV_STRIDE_H=1 @@ -77,14 +77,14 @@ func @all_params(%filter : memref, %input : memref, %outpu // CHECK: -DCK_PARAM_PROBLEM_X=3 // CHECK: -DCK_PARAM_PROBLEM_Y=3 // CHECK: -DCK_PARAM_TUNABLE_BLOCK_SIZE=256 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=8 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M=32 -// CHECK: 
-DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M=4 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM=4 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=8 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N=32 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N=4 -// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM=4 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=2 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M=128 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M=1 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM=1 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=2 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N=128 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N=1 +// CHECK: -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM=1 // CHECK: -DCK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DST_DATA_PER_WRITE_GEMM_N1=1 // CHECK: -DCK_PARAM_TUNABLE_GEMM_K_PER_BLOCK=16 // CHECK: -DCK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER=4 diff --git a/mlir/test/mlir-miopen-driver/misc_e2e/conv2d_host_validation.mlir b/mlir/test/mlir-miopen-driver/misc_e2e/conv2d_host_validation.mlir index 37fea7ab7f43..a97c2653b33a 100644 --- a/mlir/test/mlir-miopen-driver/misc_e2e/conv2d_host_validation.mlir +++ b/mlir/test/mlir-miopen-driver/misc_e2e/conv2d_host_validation.mlir @@ -10,7 +10,7 @@ // RUN: mlir-miopen-driver --operation conv2d -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NCHW_CONFIG2_FWD -// FIXME: mlir-miopen-driver --operation conv2d_bwd_data -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NCHW_CONFIG2_BWD +// RUN: mlir-miopen-driver --operation conv2d_bwd_data -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NCHW_CONFIG2_BWD // RUN: mlir-miopen-driver --operation conv2d_bwd_weight -t f32 -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | 
mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NCHW_CONFIG2_WRW @@ -112,7 +112,7 @@ // RUN: mlir-miopen-driver --operation conv2d -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NHWC_CONFIG2_FWD -// FIXME: mlir-miopen-driver --operation conv2d_bwd_data -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NHWC_CONFIG2_BWD +// RUN: mlir-miopen-driver --operation conv2d_bwd_data -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NHWC_CONFIG2_BWD // RUN: mlir-miopen-driver --operation conv2d_bwd_weight -t f32 -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=256 -groupsize=32 -in_channels=1024 -out_channels=1024 -in_h=7 -in_w=7 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=1 --conv_stride_w=1 -pv -c | mlir-rocm-runner --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CHECK_RESNET101_NHWC_CONFIG2_WRW diff --git a/mlir/test/mlir-miopen-driver/populate.mlir b/mlir/test/mlir-miopen-driver/populate.mlir index 83a4a949d47b..914232f93bb1 100644 --- a/mlir/test/mlir-miopen-driver/populate.mlir +++ b/mlir/test/mlir-miopen-driver/populate.mlir @@ -4,12 +4,12 @@ // F32-LABEL: module // F32-NEXT: func @miopen_conv2d_gkcyx_ngchw_ngkhw_0({{.*}}: memref<1x128x8x3x3xf32>, {{.*}}: memref<128x1x8x32x32xf32>, {{.*}}: memref<128x1x128x30x30xf32>) attributes {kernel = 0 : i32} -// F32-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x128x8x3x3xf32>, memref<128x1x8x32x32xf32>, memref<128x1x128x30x30xf32> +// F32-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, 
input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x128x8x3x3xf32>, memref<128x1x8x32x32xf32>, memref<128x1x128x30x30xf32> // F16-LABEL: module // F16-NEXT: func @miopen_conv2d_gkcyx_ngchw_ngkhw_0({{.*}}: memref<1x128x8x3x3xf16>, {{.*}}: memref<128x1x8x32x32xf16>, {{.*}}: memref<128x1x128x30x30xf16>) attributes {kernel = 0 : i32} -// F16-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x128x8x3x3xf16>, memref<128x1x8x32x32xf16>, memref<128x1x128x30x30xf16> +// F16-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x128x8x3x3xf16>, memref<128x1x8x32x32xf16>, memref<128x1x128x30x30xf16> // BF16-LABEL: module // BF16-NEXT: func @miopen_conv2d_gkcyx_ngchw_ngkhw_0({{.*}}: memref<1x128x8x3x3xi16>, {{.*}}: memref<128x1x8x32x32xi16>, {{.*}}: memref<128x1x128x30x30xi16>) attributes {kernel = 0 : i32} -// BF16-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x128x8x3x3xi16>, memref<128x1x8x32x32xi16>, memref<128x1x128x30x30xi16> +// BF16-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x128x8x3x3xi16>, memref<128x1x8x32x32xi16>, memref<128x1x128x30x30xi16> diff --git a/mlir/test/mlir-miopen-driver/populate_bwd_multi_kernels.mlir b/mlir/test/mlir-miopen-driver/populate_bwd_multi_kernels.mlir new file mode 100644 index 000000000000..af544b48f02c --- /dev/null +++ b/mlir/test/mlir-miopen-driver/populate_bwd_multi_kernels.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-miopen-driver -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=32 -in_channels=32 -out_channels=32 -in_h=14 -in_w=14 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=2 --conv_stride_w=2 --groupsize=1 -pv --operation=conv2d_bwd_data -miopen-lowering -miopen-affine-transform -miopen-affix-params | FileCheck %s --check-prefix=STRIDE2 + +// RUN: mlir-miopen-driver -p=false -fil_layout=gkyxc -in_layout=nhwgc -out_layout=nhwgk -batchsize=32 -in_channels=32 -out_channels=32 -in_h=14 -in_w=14 -fil_h=3 -fil_w=3 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=2 --conv_stride_w=2 --groupsize=1 -pv --operation=conv2d_bwd_data -miopen-lowering -miopen-affine-transform 
-miopen-affix-params | FileCheck %s --check-prefix=STRIDE2_GKYXC + + +// RUN: mlir-miopen-driver -p=false -fil_layout=gkcyx -in_layout=ngchw -out_layout=ngkhw -batchsize=32 -in_channels=32 -out_channels=32 -in_h=14 -in_w=14 -fil_h=1 -fil_w=1 --dilation_h=1 --dilation_w=1 --padding_h=1 --padding_w=1 --conv_stride_h=2 --conv_stride_w=2 --groupsize=1 -pv --operation=conv2d_bwd_data -miopen-lowering -miopen-affine-transform -miopen-affix-params | FileCheck %s --check-prefix=STRIDE2_1x1 + + +// STRIDE2: {{miopen.gridwise_gemm.*gemm_id = 0 : i32.*matrix_a_source_data_per_read = 1 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 1 : i32, matrix_b_source_vector_read_dim = 2 : i32.*}} +// STRIDE2: {{miopen.gridwise_gemm.*gemm_id = 1 : i32.*matrix_a_source_data_per_read = 1 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 1 : i32, matrix_b_source_vector_read_dim = 2 : i32.*}} +// STRIDE2: {{miopen.gridwise_gemm.*gemm_id = 2 : i32.*matrix_a_source_data_per_read = 1 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 1 : i32, matrix_b_source_vector_read_dim = 2 : i32.*}} +// STRIDE2: {{miopen.gridwise_gemm.*gemm_id = 3 : i32.*matrix_a_source_data_per_read = 1 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 1 : i32, matrix_b_source_vector_read_dim = 2 : i32.*}} + + +// STRIDE2_GKYXC: {{miopen.gridwise_gemm.*gemm_id = 0 : i32.*matrix_a_source_data_per_read = 4 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 4 : i32, matrix_b_source_vector_read_dim = 1 : i32.*}} +// STRIDE2_GKYXC: {{miopen.gridwise_gemm.*gemm_id = 1 : i32.*matrix_a_source_data_per_read = 4 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 4 : i32, matrix_b_source_vector_read_dim = 1 : i32.*}} +// STRIDE2_GKYXC: {{miopen.gridwise_gemm.*gemm_id = 2 : i32.*matrix_a_source_data_per_read = 4 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 4 : i32, matrix_b_source_vector_read_dim = 1 : i32.*}} +// STRIDE2_GKYXC: {{miopen.gridwise_gemm.*gemm_id = 3 : i32.*matrix_a_source_data_per_read = 4 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 4 : i32, matrix_b_source_vector_read_dim = 1 : i32.*}} + + +// STRIDE2_1x1: {{miopen.gridwise_gemm.*gemm_id = 0 : i32.*matrix_a_source_data_per_read = 4 : i32, matrix_a_source_vector_read_dim = 2 : i32.*matrix_b_source_data_per_read = 4 : i32, matrix_b_source_vector_read_dim = 2 : i32.*}} diff --git a/mlir/test/mlir-miopen-driver/populate_padding.mlir b/mlir/test/mlir-miopen-driver/populate_padding.mlir index dffdb8d3bbb1..c11216573789 100644 --- a/mlir/test/mlir-miopen-driver/populate_padding.mlir +++ b/mlir/test/mlir-miopen-driver/populate_padding.mlir @@ -3,10 +3,10 @@ // Padding_One-LABEL: module // Padding_One-NEXT: func @miopen_conv2d_gkcyx_ngchw_ngkhw_0({{.*}}: memref<1x256x32x1x1xf32>, {{.*}}: memref<32x1x32x14x14xf32>, {{.*}}: memref<32x1x256x14x17xf32>) attributes {kernel = 0 : i32} -// Padding_One-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 1 : i32, 2 : i32], strides = [1 : i32, 1 : i32]} : memref<1x256x32x1x1xf32>, memref<32x1x32x14x14xf32>, memref<32x1x256x14x17xf32> +// Padding_One-NEXT: 
miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 1 : i32, 2 : i32], strides = [1 : i32, 1 : i32]} : memref<1x256x32x1x1xf32>, memref<32x1x32x14x14xf32>, memref<32x1x256x14x17xf32> // Padding_Two-LABEL: module // Padding_Two-NEXT: func @miopen_conv2d_gkcyx_ngchw_ngkhw_0({{.*}}: memref<1x256x32x1x1xf32>, {{.*}}: memref<32x1x32x14x14xf32>, {{.*}}: memref<32x1x256x20x17xf32>) attributes {kernel = 0 : i32} -// Padding_Two-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [3 : i32, 3 : i32, 1 : i32, 2 : i32], strides = [1 : i32, 1 : i32]} : memref<1x256x32x1x1xf32>, memref<32x1x32x14x14xf32>, memref<32x1x256x20x17xf32> +// Padding_Two-NEXT: miopen.conv2d({{.*}}, {{.*}}, {{.*}}) {arch = "{{gfx[0-9]+}}", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = {{[0-9]+}} : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [3 : i32, 3 : i32, 1 : i32, 2 : i32], strides = [1 : i32, 1 : i32]} : memref<1x256x32x1x1xf32>, memref<32x1x32x14x14xf32>, memref<32x1x256x20x17xf32> diff --git a/mlir/test/mlir-miopen-driver/populate_subkernels.mlir b/mlir/test/mlir-miopen-driver/populate_subkernels.mlir index d692d167bb5b..bb62cd1b07c8 100644 --- a/mlir/test/mlir-miopen-driver/populate_subkernels.mlir +++ b/mlir/test/mlir-miopen-driver/populate_subkernels.mlir @@ -5,17 +5,17 @@ // KERNEL0-LABEL: module // KERNEL0-NEXT: func @conv2d_fwd(%arg0: memref<1x1024x1024x1x1xf32>, %arg1: memref<64x1x1024x14x14xf32>, %arg2: memref<64x1x1024x14x14xf32>) attributes {kernel = 0 : i32} { -// KERNEL0-NEXT: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> +// KERNEL0-NEXT: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> // KERNEL1-LABEL: module // KERNEL1-NEXT: func @conv2d_fwd(%arg0: memref<1x1024x1024x1x1xf32>, %arg1: memref<64x1x1024x14x14xf32>, %arg2: memref<64x1x1024x14x14xf32>) attributes {kernel = 1 : i32} { -// KERNEL1-NEXT: miopen.conv2d_dummy(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, 
memref<64x1x1024x14x14xf32> +// KERNEL1-NEXT: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 1 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> // KERNEL2-LABEL: module // KERNEL2-NEXT: func @conv2d_fwd(%arg0: memref<1x1024x1024x1x1xf32>, %arg1: memref<64x1x1024x14x14xf32>, %arg2: memref<64x1x1024x14x14xf32>) attributes {kernel = 2 : i32} { -// KERNEL2-NEXT: miopen.conv2d_dummy(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> +// KERNEL2-NEXT: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 2 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> // KERNEL3-LABEL: module // KERNEL3-NEXT: func @conv2d_fwd(%arg0: memref<1x1024x1024x1x1xf32>, %arg1: memref<64x1x1024x14x14xf32>, %arg2: memref<64x1x1024x14x14xf32>) attributes {kernel = 3 : i32} { -// KERNEL3-NEXT: miopen.conv2d_dummy(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> +// KERNEL3-NEXT: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 3 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> diff --git a/mlir/test/mlir-miopen-lib/populate_bwd.mlir b/mlir/test/mlir-miopen-lib/populate_bwd.mlir index 121fdedc8345..fdfd596a9429 100644 --- a/mlir/test/mlir-miopen-lib/populate_bwd.mlir +++ b/mlir/test/mlir-miopen-lib/populate_bwd.mlir @@ -1,7 +1,7 @@ // RUN: mlir-miopen-lib-test --args " --operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name conv2d_bwd --groupsize 1" --option cflags | FileCheck %s --check-prefix=CFLAGS // RUN: mlir-miopen-lib-test --args " --operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout 
NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name foo --groupsize 1" --option source | FileCheck %s --check-prefix=SOURCE // RUN: mlir-miopen-lib-test --args " --operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name bar --groupsize 1" --option header | FileCheck %s --check-prefix=HEADER -// RUN: mlir-miopen-lib-test --args " --operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name conv2d_nchw_nchw_nchw --groupsize 1" --option bin | FileCheck %s --check-prefix=BIN +// RUN: mlir-miopen-lib-test --args " --operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 3 --fil_w 3 --dilation_h 1 --dilation_w 1 --conv_stride_h 2 --conv_stride_w 2 --padding_h 0 --padding_w 0 --kernel_name conv2d_nchw_nchw_nchw --groupsize 1" --option bin | FileCheck %s --check-prefix=BIN // RUN: mlir-miopen-lib-test --args " --operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name bar --groupsize 1" --option tuningparams | FileCheck %s --check-prefix=TUNING // RUN: mlir-miopen-driver --conv-config "--operation conv2d_bwd_data --arch gfx906 --num_cu 64 --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name bar --groupsize 1 " | FileCheck %s --check-prefix=DRIVER @@ -9,7 +9,10 @@ // SOURCE: void mlir_gen_igemm_conv2d_cpp_v4r1_bwd // HEADER: struct MlirGenIgemmConv2dV1r1Bwd // BIN: ELF +// BIN: ELF +// BIN: ELF +// BIN: ELF // TUNING: globalSize{{.*}}localSize{{.*}} -// DRIVER: miopen.conv2d_bwd_data(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> +// DRIVER: miopen.conv2d_bwd_data(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = 
["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> diff --git a/mlir/test/mlir-miopen-lib/populate_bww.mlir b/mlir/test/mlir-miopen-lib/populate_bww.mlir index ff431eccb815..0ef0810c891b 100644 --- a/mlir/test/mlir-miopen-lib/populate_bww.mlir +++ b/mlir/test/mlir-miopen-lib/populate_bww.mlir @@ -12,4 +12,4 @@ // BIN: ELF // TUNING: globalSize{{.*}}localSize{{.*}} -// DRIVER: miopen.conv2d_bwd_weight(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> +// DRIVER: miopen.conv2d_bwd_weight(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> diff --git a/mlir/test/mlir-miopen-lib/populate_fw.mlir b/mlir/test/mlir-miopen-lib/populate_fw.mlir index 8b8b8c05f966..d9ab15da6c15 100644 --- a/mlir/test/mlir-miopen-lib/populate_fw.mlir +++ b/mlir/test/mlir-miopen-lib/populate_fw.mlir @@ -11,6 +11,6 @@ // BIN: ELF // TUNING: globalSize{{.*}}localSize{{.*}} -// DRIVER: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> +// DRIVER: miopen.conv2d(%arg0, %arg1, %arg2) {arch = "gfx906", dilations = [1 : i32, 1 : i32], filter_layout = ["g", "k", "c", "y", "x"], gemm_id = 0 : i32, input_layout = ["ni", "gi", "ci", "hi", "wi"], num_cu = 64 : i32, output_layout = ["no", "go", "ko", "ho", "wo"], padding = [0 : i32, 0 : i32, 0 : i32, 0 : i32], strides = [1 : i32, 1 : i32]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32> diff --git a/mlir/tools/mlir-miopen-driver/mlir-miopen-driver.cpp b/mlir/tools/mlir-miopen-driver/mlir-miopen-driver.cpp index c0aa9e2f3019..50eb921d2627 100644 --- a/mlir/tools/mlir-miopen-driver/mlir-miopen-driver.cpp +++ b/mlir/tools/mlir-miopen-driver/mlir-miopen-driver.cpp @@ -2391,7 +2391,7 @@ int main(int argc, char **argv) { // Populate the module. 
if (!populateCpuConvolution.getValue()) { if (genConfig.kernelId < 0) { - // generate all sub-kernels + // generate all sub-kernels, and get corresponding gemmId int kernelCount = conv2dGenerator.getKernelCount(); auto knSize = genConfig.kernelName.size(); std::string kernelBaseName = genConfig.kernelName.substr(0, knSize - 1); diff --git a/mlir/tools/mlir-miopen-lib/mlir-miopen-lib-test.cpp b/mlir/tools/mlir-miopen-lib/mlir-miopen-lib-test.cpp index baf303a44b9e..99e61e387366 100644 --- a/mlir/tools/mlir-miopen-lib/mlir-miopen-lib-test.cpp +++ b/mlir/tools/mlir-miopen-lib/mlir-miopen-lib-test.cpp @@ -24,6 +24,8 @@ int main(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv, "MLIR MIOpen Dialect driver\n"); MiirStatus status = MIIR_SUCCESS; + // save args + std::string parameters = args.getValue(); MiirHandle handle = miirCreateHandle(args.getValue().c_str()); @@ -59,32 +61,41 @@ int main(int argc, char **argv) { << ", localSize=" << localSize << std::endl; } else if (option.getValue() == "bin") { - status = miirLowerBin(handle); - if (status != MIIR_SUCCESS) { - return status; - } + int count = miirGetKernelCount(handle); + for (int i = 0; i < count; i++) { + auto arguments = parameters + " --kernel_id " + std::to_string(i); - size_t size = 0; - status = miirBufferGet(handle, nullptr, &size); - if (status != MIIR_SUCCESS) { - return status; - } - std::vector buffer(size); - status = miirBufferGet(handle, buffer.data(), &size); - if (status != MIIR_SUCCESS) { - return status; - } - std::for_each(buffer.begin(), buffer.end(), - [](char &c) { std::cout << c; }); - std::cout << std::endl; + MiirHandle newHandle = miirCreateHandle(arguments.c_str()); - size_t globalSize, localSize; - status = miirGetExecutionDims(handle, &globalSize, &localSize); - if (status != MIIR_SUCCESS) { - return status; + status = miirLowerBin(newHandle); + if (status != MIIR_SUCCESS) { + return status; + } + + size_t size = 0; + status = miirBufferGet(newHandle, nullptr, &size); + if (status != MIIR_SUCCESS) { + return status; + } + std::vector buffer(size); + status = miirBufferGet(newHandle, buffer.data(), &size); + if (status != MIIR_SUCCESS) { + return status; + } + std::for_each(buffer.begin(), buffer.end(), + [](char &c) { std::cout << c; }); + std::cout << std::endl; + + size_t globalSize, localSize; + status = miirGetExecutionDims(newHandle, &globalSize, &localSize); + if (status != MIIR_SUCCESS) { + return status; + } + std::cout << "ExecutionDims - globalSize=" << globalSize + << ", localSize=" << localSize << ", kernelId = " << i + << std::endl; + miirDestroyHandle(newHandle); } - std::cout << "ExecutionDims - globalSize=" << globalSize - << ", localSize=" << localSize << std::endl; } miirDestroyHandle(handle);
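Note on the backward-data multi-kernel logic introduced above: the number of sub-kernels and the gemm_id assigned to each are derived only from the strides, dilations, and filter size. Below is a minimal, self-contained sketch of that calculation (not part of the patch itself); validGemmIds and the local divCeil helper are illustrative stand-ins for the patch's getBwdDataKernelCount/getGemmId and the MIOpen math:: utilities, with std::gcd taken from the standard library.

#include <cassert>
#include <numeric>
#include <vector>

// Local stand-in for math::integer_divide_ceil; only here so the sketch
// compiles on its own.
static int divCeil(int a, int b) { return (a + b - 1) / b; }

// Enumerate the gemm_ids whose filter slice is non-empty for a backward-data
// convolution, mirroring the getBwdDataKernelCount/getGemmId logic above.
static std::vector<int> validGemmIds(int strideH, int strideW, int dilationH,
                                     int dilationW, int y, int x) {
  int yTilda = strideH / std::gcd(strideH, dilationH);
  int xTilda = strideW / std::gcd(strideW, dilationW);
  std::vector<int> ids;
  for (int gemmId = 0; gemmId < yTilda * xTilda; ++gemmId) {
    int iYTilda = gemmId / xTilda;
    int iXTilda = gemmId % xTilda;
    // Each gemm_id owns the filter taps congruent to (iYTilda, iXTilda);
    // skip it when that slice is empty, since gemmK would be 0.
    int yDotSlice = divCeil(y - iYTilda, yTilda);
    int xDotSlice = divCeil(x - iXTilda, xTilda);
    if (yDotSlice * xDotSlice > 0)
      ids.push_back(gemmId);
  }
  return ids;
}

int main() {
  // 3x3 filter, stride 2, dilation 1: yTilda = xTilda = 2 and every gemm_id
  // keeps a non-empty filter slice, so four sub-kernels are generated.
  assert(validGemmIds(2, 2, 1, 1, 3, 3).size() == 4);
  // 1x1 filter, stride 2: only gemm_id 0 has a non-empty slice.
  assert(validGemmIds(2, 2, 1, 1, 1, 1).size() == 1);
  return 0;
}

Under these assumptions the counts line up with the new tests: the stride-2 3x3 configuration yields four gridwise_gemm ops (populate_bwd_multi_kernels.mlir) and four ELF objects (populate_bwd.mlir), while the stride-2 1x1 configuration keeps only gemm_id 0.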