Skip to content

Commit

Permalink
[Stablehlo] Add stablehlo support for aten.abs (llvm#2068)
Browse files Browse the repository at this point in the history
Co-authored-by: AmosLewis <[email protected]>
  • Loading branch information
Chi_Liu and AmosLewis authored May 9, 2023
1 parent c7a24c4 commit 51e0a2c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseAbsModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp);
#undef INSERT_UNARY_PATTERN

#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchToStablehlo/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,16 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.abs(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64>
// CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_1]] : tensor<15x15xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64>
// CHECK: }
func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64>{
%0 = torch.aten.abs %arg0 : !torch.vtensor<[15,15],si64> -> !torch.vtensor<[15,15],si64>
return %0 : !torch.vtensor<[15,15],si64>
}

0 comments on commit 51e0a2c

Please sign in to comment.