Skip to content

Commit

Permalink
Update according to comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Alan Li <[email protected]>
  • Loading branch information
lialan committed Jul 16, 2024
1 parent 6953441 commit a4698bf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -438,31 +438,24 @@ struct ReverseOpConversion final
Value input = op.getOperand();
auto inputTy = cast<ShapedType>(input.getType());
auto resultTy = cast<ShapedType>(op.getType());
auto dims = op.getDimensions();
ArrayRef<int64_t> dims = op.getDimensions();
Location loc = op.getLoc();

SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i)) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
}
}
int64_t inputTyRank = inputTy.getRank();

// First fill the output buffer with the init value.
auto emptyTensor = rewriter
.create<tensor::EmptyOp>(loc, inputTy.getShape(),
inputTy.getElementType(),
ArrayRef<Value>({dynDims}))
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
SmallVector<OpFoldResult> inputMixedSizes =
tensor::getMixedSizes(rewriter, loc, input);
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, inputMixedSizes, inputTy.getElementType());
SmallVector<AffineMap> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};

rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
llvm::SmallVector<Value> indices;
for (unsigned int i = 0; i < inputTy.getRank(); i++) {
for (unsigned int i = 0; i < inputTyRank; i++) {
Value index =
rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
if (std::find(dims.begin(), dims.end(), i) != dims.end()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,12 @@ func.func @reverse_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
} : (tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor<?x?xi32>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor<?x?xi32>
// CHECK: %[[INIT:.+]] = tensor.empty(%[[D]], %[[D0]]) : tensor<?x?xi32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {

// First reverse dimension
// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
Expand Down

0 comments on commit a4698bf

Please sign in to comment.