Skip to content

Commit

Permalink
calculate gemmid from kernel id
Browse files Browse the repository at this point in the history
  • Loading branch information
ltqin committed Jun 2, 2021
1 parent af92ead commit 636f052
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
23 changes: 22 additions & 1 deletion mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2029,7 +2029,28 @@ struct Conv2DRewritePattern : public OpRewritePattern<T> {
auto hTildaSlice = iHTildaRight - iHTildaLeft;
auto wTildaSlice = iWTildaRight - iWTildaLeft;

auto gemmId = gemmIdAttr.getInt();
auto getGemmId = [&](int kernelId) {
// kernelId 0 must be gemmId 0
if (kernelId <= 0)
return 0;

llvm::SmallVector<int> gemmIds;
for (int gemmId = 0; gemmId < yTilda * xTilda; gemmId++) {
// gemm_k size is different for each GEMM
const auto iYTilda = gemmId / xTilda;
const auto iXTilda = gemmId % xTilda;

auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
auto xDotSlice = math::integer_divide_ceil(x - iXTilda, xTilda);
// gemmK must > 0, otherwise not need to run
if (yDotSlice * xDotSlice > 0) {
gemmIds.push_back(gemmId);
}
}
assert(gemmIds.size() > kernelId);
return gemmIds[kernelId];
};
auto gemmId = getGemmId(gemmIdAttr.getInt());
auto iYTilda = gemmId / xTilda;
auto iXTilda = gemmId % xTilda;
auto yDotSlice = math::integer_divide_ceil(y - iYTilda, yTilda);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/MIOpen/Generator/Conv2dGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ LogicalResult Conv2dGenerator::genConvModule(ModuleOp &module,
OpBuilder &builder,
int kernel_id) {
if (kernel_id == -1) {
kernel_id = config.kernelId;
kernel_id = std::max(config.kernelId, 0);
}

Type dataType = getDataType(builder);
Expand Down

0 comments on commit 636f052

Please sign in to comment.