From dda3aa8ff336d64f39de1d8e69141c7ff16396ca Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 29 Jun 2023 10:37:13 +0800 Subject: [PATCH] [Torch Dialect] add more scalar op folders (#2265) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 3 ++ lib/Dialect/Torch/IR/TorchOps.cpp | 36 +++++++++++++++++++ .../jit_ir/build_tools/torch_ods_gen.py | 6 ++-- test/Dialect/Torch/canonicalize.mlir | 29 +++++++++++++++ 4 files changed, 71 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5af833a911c93..a4b90d69e8e01 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -10554,6 +10554,7 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ @@ -10603,6 +10604,7 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [ @@ -10651,6 +10653,7 @@ def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e6e62b0baec74..84e6bd3654e86 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2319,6 +2319,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenMulFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) { + return atenBinaryFloatOperatorFoldHelper( + 
adaptor.getOperands(), [](double a, double b) { return a * b; }); } + //===----------------------------------------------------------------------===// // AtenSubFloatOp //===----------------------------------------------------------------------===// @@ -2381,6 +2390,18 @@ OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) { [](double a, double b) -> double { return a / b; }); } +//===----------------------------------------------------------------------===// +// AtenAddFloatIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { + return nullptr; + } + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), [](double a, double b) { return a + b; }); +} + //===----------------------------------------------------------------------===// // AtenPowIntFloatOp //===----------------------------------------------------------------------===// @@ -2421,6 +2442,21 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenNegFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA()) { + return nullptr; + } + auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>(); + if (!value) { + return nullptr; + } + return getF64FloatAttr(getContext(), -value.getValue().convertToDouble()); +} + //===----------------------------------------------------------------------===// // AtenSqrtIntOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 6637f1a8c43fe..7646fda6cf47a 100644 ---
a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -633,11 +633,11 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") - emit("aten::add.float_int : (float, int) -> (float)") + emit("aten::add.float_int : (float, int) -> (float)", has_folder=True) emit("aten::sub.float : (float, float) -> (float)", has_folder=True) - emit("aten::mul.float : (float, float) -> (float)") + emit("aten::mul.float : (float, float) -> (float)", has_folder=True) emit("aten::div.float : (float, float) -> (float)", has_folder=True) - emit("aten::neg.float : (float) -> (float)") + emit("aten::neg.float : (float) -> (float)", has_folder=True) emit("aten::eq.float : (float, float) -> (bool)", has_folder=True) emit("aten::gt.float : (float, float) -> (bool)", has_folder=True) emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 0cd59fd077285..16382f52ec8ae 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1036,6 +1036,16 @@ func.func @torch.aten.add.int() -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.add.float_int() -> !torch.float { +// CHECK: %[[CST9:.*]] = torch.constant.float 9.000000e+00 +// CHECK: return %[[CST9]] : !torch.float +func.func @torch.aten.add.float_int() -> !torch.float { + %cst4 = torch.constant.float 4.0 + %cst5 = torch.constant.int 5 + %ret = torch.aten.add.float_int %cst4, %cst5: !torch.float, !torch.int -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.sub.int() -> !torch.int { // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: return %[[CST1]] : !torch.int @@ -1056,6 
+1066,25 @@ func.func @torch.aten.mul.int() -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { +// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 +// CHECK: return %[[CST30]] : !torch.float +func.func @torch.aten.mul.float() -> !torch.float { + %cst6 = torch.constant.float 6.0 + %cst5 = torch.constant.float 5.0 + %ret = torch.aten.mul.float %cst6, %cst5: !torch.float, !torch.float -> !torch.float + return %ret : !torch.float +} + +// CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float { +// CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00 +// CHECK: return %[[CST_6]] : !torch.float +func.func @torch.aten.neg.float() -> !torch.float { + %cst6 = torch.constant.float 6.0 + %ret = torch.aten.neg.float %cst6: !torch.float -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.mul.int$with_zero() -> !torch.int { // CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: return %[[CST0]] : !torch.int