From 796888ac6a4e97a19f15d92394e189a67d777778 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 16 Jul 2024 02:55:57 +0000 Subject: [PATCH] Update according to comments. Signed-off-by: Alan Li --- .../Conversion/StableHLOToLinalgExt.cpp | 21 +++++++------------ .../test/stablehlo_to_linalg_ext.mlir | 16 ++++++++------ 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp index 0f32f678bbf85..1b2884bd569b5 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp @@ -440,21 +440,14 @@ struct ReverseOpConversion final auto resultTy = cast(op.getType()); auto dims = op.getDimensions(); Location loc = op.getLoc(); - - SmallVector dynDims; - for (int i = 0; i < inputTy.getRank(); i++) { - if (inputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, input, i)); - } - } + auto inputTyRank = inputTy.getRank(); // First fill the output buffer with the init value. - auto emptyTensor = rewriter - .create(loc, inputTy.getShape(), - inputTy.getElementType(), - ArrayRef({dynDims})) - .getResult(); - SmallVector affineMaps = { + SmallVector inputMixedSizes = + tensor::getMixedSizes(rewriter, loc, input); + auto emptyTensor = rewriter.create( + loc, inputMixedSizes, inputTy.getElementType()); + SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( @@ -462,7 +455,7 @@ struct ReverseOpConversion final getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { llvm::SmallVector indices; - for (unsigned int i = 0; i < inputTy.getRank(); i++) { + for (unsigned int i = 0; i < inputTyRank; i++) { Value index = rewriter.create(nestedLoc, i).getResult(); if (std::find(dims.begin(), dims.end(), i) != dims.end()) { diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir index 7c2266b7d51b9..59dc41d3fe4a8 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir +++ b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir @@ -540,12 +540,16 @@ func.func @reverse_multi_dim(%arg0: tensor) -> tensor { } : (tensor) -> tensor return %0 : tensor } -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor -// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor) { +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor +// CHECK: %[[C0_1:.+]] = arith.constant 0 : index +// CHECK: %[[D2:.+]] = tensor.dim %[[IN]], %[[C0_1]] +// CHECK: %[[C1_3:.+]] = arith.constant 1 : index +// CHECK: %[[D4:.+]] = tensor.dim %[[IN]], %[[C1_3]] +// CHECK: %[[INIT:.+]] = tensor.empty(%[[D2]], %[[D4]]) : tensor +// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor) { // First reverse dimension // CHECK: %[[IDX0:.+]] = linalg.index 0 : index