From 646ac71ec78e1b34b06f402b261d46eec7f879d4 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Tue, 30 Aug 2022 14:44:00 -0500 Subject: [PATCH] build: update llvm tag to 00d648bd (#1307) - Update MHLO commit to build with LLVM commit hash 00d648bd - Update TorchToMhlo code to work with Stablehlo - Re-enabled two failing TOSA tests, thus resolving Github Issue #1231 --- e2e_testing/xfail_sets.py | 2 ++ externals/llvm-project | 2 +- externals/mlir-hlo | 2 +- lib/Conversion/TorchToMhlo/Basic.cpp | 31 ++++++++++---------- lib/Conversion/TorchToMhlo/CMakeLists.txt | 4 +-- lib/Conversion/TorchToMhlo/Linear.cpp | 2 +- lib/Conversion/TorchToMhlo/Pooling.cpp | 2 +- lib/Conversion/TorchToMhlo/TorchToMhlo.cpp | 2 +- test/Conversion/TorchToMhlo/elementwise.mlir | 12 ++++---- 9 files changed, 31 insertions(+), 28 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 37e6f06d8d490..1cfa3652e4cd8 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -163,6 +163,7 @@ "ElementwiseBinaryModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseExpModule_basic", + "ElementwiseReluModule_basic", "ElementwiseFloorModule_basic", "ElementwiseLogModule_basic", "ElementwiseBinaryStaticShapeModule_basic", @@ -237,6 +238,7 @@ "ElementwiseFlattenBroadcastModule_basic", "SquareModule_basic", "MaxPool2dStaticModule_basic", + "ResNet18StaticModule_basic", "NativeLayerNormModule4D_basic", "LayerNormNormalizeOverAllDimsModule_basic", "PermuteModule_basic", diff --git a/externals/llvm-project b/externals/llvm-project index 2dde4ba63974d..00d648bdb5a8b 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 2dde4ba63974daf59f8ce5c346505f194f920131 +Subproject commit 00d648bdb5a8b71785269b4851b651c883de2cd9 diff --git a/externals/mlir-hlo b/externals/mlir-hlo index 9c49473d80a86..305a2f2522966 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit 9c49473d80a8667e94232ddb5ed60a1a9d8ad266 +Subproject commit 305a2f25229660ea789bf70ed8e7336227f6228a diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index 81ba3dfbdef42..0fafbf1336cdd 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -12,10 +12,11 @@ #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" -#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/utils/hlo_utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -291,33 +292,33 @@ class ConvertAtenCompareOp : public OpConversionPattern { // TODO: what is the PyTorch default type promotion? rhs = mhlo::promoteType(rewriter, rhs, lhsTy); - mhlo::ComparisonTypeAttr compareTypeAttr; - mhlo::ComparisonDirectionAttr compareDirectionAttr; + chlo::ComparisonTypeAttr compareTypeAttr; + chlo::ComparisonDirectionAttr compareDirectionAttr; if (lhsElemTy.isa()) { - compareTypeAttr = mhlo::ComparisonTypeAttr::get( - op->getContext(), mhlo::ComparisonType::FLOAT); + compareTypeAttr = chlo::ComparisonTypeAttr::get( + op->getContext(), chlo::ComparisonType::FLOAT); } else if (lhsElemTy.isa()) { - compareTypeAttr = mhlo::ComparisonTypeAttr::get( - op->getContext(), mhlo::ComparisonType::SIGNED); + compareTypeAttr = chlo::ComparisonTypeAttr::get( + op->getContext(), chlo::ComparisonType::SIGNED); } if (std::is_same() || std::is_same()) { - compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( - op->getContext(), mhlo::ComparisonDirection::LT); + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::LT); } else if (std::is_same() || std::is_same()) { - compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( - op->getContext(), mhlo::ComparisonDirection::GT); + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::GT); } else if (std::is_same() || std::is_same()) { - compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( - op->getContext(), mhlo::ComparisonDirection::EQ); + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::EQ); } else if (std::is_same() || std::is_same()) { - compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( - op->getContext(), mhlo::ComparisonDirection::NE); + compareDirectionAttr = chlo::ComparisonDirectionAttr::get( + op->getContext(), chlo::ComparisonDirection::NE); } DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index b3468cb7be4d2..39d956fddb176 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -13,7 +13,6 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo DEPENDS MhloDialect - ChloDialect MhloToLinalg MLIRMhloPassIncGen TorchMLIRConversionPassIncGen @@ -22,11 +21,12 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo Core LINK_LIBS PUBLIC + ChloOps MLIRIR MLIRPass MhloDialect - ChloDialect MhloToLinalg + StablehloBase TorchMLIRTorchDialect ) diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index 0bb2e388bcffd..da54955c479fb 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -12,10 +12,10 @@ #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" -#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToMhlo/Pooling.cpp b/lib/Conversion/TorchToMhlo/Pooling.cpp index 3c74e23dc7b94..6bb07579bb097 100644 --- a/lib/Conversion/TorchToMhlo/Pooling.cpp +++ b/lib/Conversion/TorchToMhlo/Pooling.cpp @@ -12,10 +12,10 @@ #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" -#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index 3ac702b5b030e..741854c84102a 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -11,13 +11,13 @@ #include "../PassDetail.h" #include "./PopulatePatterns.h" -#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir index aaf5720aa10e2..ae41c3fd65dcd 100644 --- a/test/Conversion/TorchToMhlo/elementwise.mlir +++ b/test/Conversion/TorchToMhlo/elementwise.mlir @@ -372,7 +372,7 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor -// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor +// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1> func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { @@ -387,7 +387,7 @@ func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { @@ -401,7 +401,7 @@ func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { @@ -415,7 +415,7 @@ func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { @@ -429,7 +429,7 @@ func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { @@ -553,7 +553,7 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor -// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor +// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1> func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],i1> {