Skip to content

Commit

Permalink
Update according to comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Alan Li <[email protected]>
  • Loading branch information
lialan committed Jul 16, 2024
1 parent 923869a commit 10193a9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -431,38 +431,39 @@ struct ReverseOpConversion final
LogicalResult
matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto ty = dyn_cast<RankedTensorType>(adaptor.getOperands()[0].getType());
const auto ty =
dyn_cast<RankedTensorType>(adaptor.getOperands()[0].getType());
if (!ty)
return failure();

Value input = op.getOperand();
auto inputTy = cast<ShapedType>(input.getType());
auto resultTy = cast<ShapedType>(op.getType());
auto dims = op.getDimensions();
Location loc = op.getLoc();
const auto inputTy = cast<ShapedType>(input.getType());
const auto resultTy = cast<ShapedType>(op.getType());
const auto dims = op.getDimensions();
const Location loc = op.getLoc();
const auto inputTyRank = inputTy.getRank();

SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
for (int i = 0; i < inputTyRank; i++) {
if (inputTy.isDynamicDim(i)) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
}
}

// First fill the output buffer with the init value.
auto emptyTensor = rewriter
.create<tensor::EmptyOp>(loc, inputTy.getShape(),
inputTy.getElementType(),
ArrayRef<Value>({dynDims}))
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
SmallVector<OpFoldResult> inputMixedSizes =
tensor::getMixedSizes(rewriter, loc, input);
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, inputMixedSizes, inputTy.getElementType());
SmallVector<AffineMap> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};

rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
llvm::SmallVector<Value> indices;
for (unsigned int i = 0; i < inputTy.getRank(); i++) {
for (unsigned int i = 0; i < inputTyRank; i++) {
Value index =
rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
if (std::find(dims.begin(), dims.end(), i) != dims.end()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,16 @@ func.func @reverse_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
} : (tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor<?x?xi32>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor<?x?xi32>
// CHECK: %[[C0_1:.+]] = arith.constant 0 : index
// CHECK: %[[D2:.+]] = tensor.dim %[[IN]], %[[C0_1]]
// CHECK: %[[C1_3:.+]] = arith.constant 1 : index
// CHECK: %[[D4:.+]] = tensor.dim %[[IN]], %[[C1_3]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[D2]], %[[D4]]) : tensor<?x?xi32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {

// First reverse dimension
// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
Expand Down

0 comments on commit 10193a9

Please sign in to comment.