Skip to content

Commit

Permalink
Add dtype functions for comparison ops (#1806)
Browse files Browse the repository at this point in the history
This commit adds dtype functions for comparison ops that always return
a tensor of dtype `i1`.
  • Loading branch information
ramiro050 authored Jan 16, 2023
1 parent 5b77c15 commit 8cae5ba
Show file tree
Hide file tree
Showing 4 changed files with 370 additions and 52 deletions.
280 changes: 280 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4164,6 +4164,226 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.torch.jit._shape_functions.movedim(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %int-1 = torch.constant.int -1\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.le.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" torch.prim.If.yield %arg0 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %4 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %5 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %5, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %25 = torch.prim.If %24 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %0 : !torch.int\n"
" }\n"
" %26 = torch.aten.neg.int %25 : !torch.int -> !torch.int\n"
" %27 = torch.aten.sub.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %28 = torch.aten.lt.int %23, %26 : !torch.int, !torch.int -> !torch.bool\n"
" %29 = torch.prim.If %28 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %45 = torch.aten.gt.int %23, %27 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %45 : !torch.bool\n"
" }\n"
" %30 = torch.aten.__not__ %29 : !torch.bool -> !torch.bool\n"
" torch.prim.If %30 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %31 = torch.aten.lt.int %23, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %32 = torch.prim.If %31 -> (!torch.int) {\n"
" %45 = torch.aten.add.int %23, %25 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %45 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %23 : !torch.int\n"
" }\n"
" %33 = torch.aten.append.t %3, %32 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %34 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %35 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %36 = torch.prim.If %35 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %0 : !torch.int\n"
" }\n"
" %37 = torch.aten.neg.int %36 : !torch.int -> !torch.int\n"
" %38 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %39 = torch.aten.lt.int %34, %37 : !torch.int, !torch.int -> !torch.bool\n"
" %40 = torch.prim.If %39 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %45 = torch.aten.gt.int %34, %38 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %45 : !torch.bool\n"
" }\n"
" %41 = torch.aten.__not__ %40 : !torch.bool -> !torch.bool\n"
" torch.prim.If %41 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %42 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %43 = torch.prim.If %42 -> (!torch.int) {\n"
" %45 = torch.aten.add.int %34, %36 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %45 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %34 : !torch.int\n"
" }\n"
" %44 = torch.aten.append.t %4, %43 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.Loop %0, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.append.t %6, %int-1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.Loop %0, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.append.t %7, %arg3 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %8 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.Loop %0, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.append.t %8, %arg3 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %9 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %9, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %3, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.__getitem__.t %4, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %25 = torch.aten._set_item.t %6, %24, %23 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" %26 = torch.aten.__getitem__.t %3, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %27 = torch.aten._set_item.t %7, %26, %int-1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" %28 = torch.aten.__getitem__.t %4, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %29 = torch.aten._set_item.t %8, %28, %int-1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %10 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %11 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %12 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %12, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %7, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.ne.int %23, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %24 -> () {\n"
" %25 = torch.aten.append.t %10, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %13 = torch.aten.len.t %8 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %13, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %8, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.ne.int %23, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %24 -> () {\n"
" %25 = torch.aten.append.t %11, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %14 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %15 = torch.aten.sub.int %0, %14 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %15, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %10, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.__getitem__.t %11, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %25 = torch.aten._set_item.t %6, %24, %23 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %16 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %17 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.eq.int %16, %17 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %18 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %19 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
" %20 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %21 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.Loop %19, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %6, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.le.int %19, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %25 = torch.prim.If %24 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %19 : !torch.int\n"
" }\n"
" %26 = torch.aten.neg.int %25 : !torch.int -> !torch.int\n"
" %27 = torch.aten.sub.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %28 = torch.aten.lt.int %23, %26 : !torch.int, !torch.int -> !torch.bool\n"
" %29 = torch.prim.If %28 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %36 = torch.aten.gt.int %23, %27 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %36 : !torch.bool\n"
" }\n"
" %30 = torch.aten.__not__ %29 : !torch.bool -> !torch.bool\n"
" torch.prim.If %30 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %31 = torch.aten.lt.int %23, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %32 = torch.prim.If %31 -> (!torch.int) {\n"
" %36 = torch.aten.add.int %23, %25 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %36 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %23 : !torch.int\n"
" }\n"
" %33 = torch.aten.append.t %20, %32 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %34 = torch.aten.__getitem__.t %arg0, %32 : !torch.list<int>, !torch.int -> !torch.int\n"
" %35 = torch.aten.append.t %21, %34 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %22 = torch.aten.__range_length %int1, %19, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %22, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %23 = torch.aten.__derive_index %arg3, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %23, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %24 = torch.aten.__getitem__.t %20, %23 : !torch.list<int>, !torch.int -> !torch.int\n"
" %25 = torch.aten.__getitem__.t %20, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %26 = torch.aten.ne.int %24, %25 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %26 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.If.yield %21 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: invalid shape\"\n"
" %false = torch.constant.bool false\n"
Expand Down Expand Up @@ -7130,6 +7350,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n"
Expand Down
12 changes: 0 additions & 12 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,18 +699,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// Dtype is always i1.
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 1);
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

// Dtype is always si64.
if (isa<AtenBincountOp>(op)) {
auto knowledge =
Expand Down
Loading

0 comments on commit 8cae5ba

Please sign in to comment.