Skip to content

Commit

Permalink
[Torch Dialect] add more scalar op folders (#2265)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu authored Jun 29, 2023
1 parent 8281935 commit 449cfb8
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 3 deletions.
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10554,6 +10554,7 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
Expand Down Expand Up @@ -10603,6 +10604,7 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
Expand Down Expand Up @@ -10651,6 +10653,7 @@ def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenMulFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) {
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), [](double a, double b) { return a * b; });
}

//===----------------------------------------------------------------------===//
// AtenSubFloatOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2381,6 +2390,18 @@ OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
[](double a, double b) -> double { return a / b; });
}

//===----------------------------------------------------------------------===//
// AtenAddFloatIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), [](double a, double b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AtenPowIntFloatOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2421,6 +2442,21 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenNegFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) {
return nullptr;
}
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
if (!value) {
return nullptr;
}
return getF64FloatAttr(getContext(), -value.getValue().convertToDouble());
}

//===----------------------------------------------------------------------===//
// AtenSqrtIntOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,11 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)", has_folder=True)
emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
emit("aten::mul.float : (float, float) -> (float)")
emit("aten::mul.float : (float, float) -> (float)", has_folder=True)
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
emit("aten::neg.float : (float) -> (float)")
emit("aten::neg.float : (float) -> (float)", has_folder=True)
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::ge.float : (float, float) -> (bool)", has_folder=True)
Expand Down
29 changes: 29 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,16 @@ func.func @torch.aten.add.int() -> !torch.int {
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.add.float_int() -> !torch.float {
// CHECK: %[[CST9:.*]] = torch.constant.float 9.000000e+00
// CHECK: return %[[CST9]] : !torch.float
func.func @torch.aten.add.float_int() -> !torch.float {
%cst4 = torch.constant.float 4.0
%cst5 = torch.constant.int 5
%ret = torch.aten.add.float_int %cst4, %cst5: !torch.float, !torch.int -> !torch.float
return %ret : !torch.float
}

// CHECK-LABEL: func.func @torch.aten.sub.int() -> !torch.int {
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: return %[[CST1]] : !torch.int
Expand All @@ -1056,6 +1066,25 @@ func.func @torch.aten.mul.int() -> !torch.int {
return %ret : !torch.int
}

// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
// CHECK: return %[[CST30]] : !torch.float
func.func @torch.aten.mul.float() -> !torch.float {
%cst6 = torch.constant.float 6.0
%cst5 = torch.constant.float 5.0
%ret = torch.aten.mul.float %cst6, %cst5: !torch.float, !torch.float -> !torch.float
return %ret : !torch.float
}

// CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float {
// CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00
// CHECK: return %[[CST_6]] : !torch.float
func.func @torch.aten.neg.float() -> !torch.float {
%cst6 = torch.constant.float 6.0
%ret = torch.aten.neg.float %cst6: !torch.float -> !torch.float
return %ret : !torch.float
}

// CHECK-LABEL: func.func @torch.aten.mul.int$with_zero() -> !torch.int {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: return %[[CST0]] : !torch.int
Expand Down

0 comments on commit 449cfb8

Please sign in to comment.