From 636f05258d9420a8c211a936ada791dd39e0e305 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 2 Jun 2021 14:57:59 +0800 Subject: [PATCH] calculate gemmid from kernel id --- .../mlir/Dialect/MIOpen/LowerMIOpenOps.h | 23 ++++++++++++++++++- .../MIOpen/Generator/Conv2dGenerator.cpp | 2 +- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index b7a6969994ef..9ea23715a64e 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -2029,7 +2029,28 @@ struct Conv2DRewritePattern : public OpRewritePattern { auto hTildaSlice = iHTildaRight - iHTildaLeft; auto wTildaSlice = iWTildaRight - iWTildaLeft; - auto gemmId = gemmIdAttr.getInt(); + auto getGemmId = [&](int kernelId) { + // kernelId 0 must be gemmId 0 + if (kernelId <= 0) + return 0; + + llvm::SmallVector gemmIds; + for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) { + // gemm_k size is different for each GEMM + const auto iYTilda = gemmId / xTilda; + const auto iXTilda = gemmId % xTilda; + + auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda); + auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda); + // gemmK must > 0, otherwise not need to run + if (yDotSlice * xDotSlice > 0) { + gemmIds.push_back(gemmId); + } + } + assert(gemmIds.size() > kernelId); + return gemmIds[kernelId]; + }; + auto gemmId = getGemmId(gemmIdAttr.getInt()); auto iYTilda = gemmId / xTilda; auto iXTilda = gemmId % xTilda; auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda); diff --git a/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp b/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp index 518e72dd4100..56817cae8a78 100644 --- a/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp +++ b/mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp @@ -254,7 +254,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module, OpBuilder &builder, int kernel_id) { if (kernel_id == -1) { - kernel_id = config.kernelId; + kernel_id = std::max(config.kernelId, 0); } Type dataType = getDataType(builder);