diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index c121877bbc3e..e44518509fbc 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9075,6 +9075,519 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %10 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.bincount\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.int) -> !torch.int {\n"
+"    %int7 = torch.constant.int 7\n"
+"    %int4 = torch.constant.int 4\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      %5 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %5 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.__is__ %arg1, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
+"    %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int4 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %int7 : !torch.int\n"
+"    }\n"
+"    return %4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>, %arg4: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %3 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %7 = torch.aten.__contains__.int_list %6, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %8 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %9 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %10 = torch.aten.__contains__.int_list %9, %2#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %11 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %12 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %12 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %13 = torch.aten.eq.int %0#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %13 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %14 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %15 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %16 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%14, %15) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %16 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.lerp.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %3 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+"    %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+"      %19 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"      %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %20 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %7 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %8 = torch.aten.__contains__.int_list %7, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n"
+"    %10 = torch.prim.If %9 -> (!torch.bool) {\n"
+"      %19 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"      %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %20 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %10 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %12 = torch.aten.__contains__.int_list %11, %2#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
+"    %14 = torch.prim.If %13 -> (!torch.bool) {\n"
+"      %19 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
+"      %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %20 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %14 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %15 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %15 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %16 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %17 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %18 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%16, %17) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %18 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %3 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %7 = torch.aten.__contains__.int_list %6, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %8 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %9 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %10 = torch.aten.__contains__.int_list %9, %2#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %11 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %12 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %12 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %13 = torch.aten.eq.int %0#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %13 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %14 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %15 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %16 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%14, %15) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %16 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %3 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+"    %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+"      %20 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"      %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %21 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %7 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %8 = torch.aten.__contains__.int_list %7, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n"
+"    %10 = torch.prim.If %9 -> (!torch.bool) {\n"
+"      %20 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"      %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %21 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %10 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %12 = torch.aten.__contains__.int_list %11, %2#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
+"    %14 = torch.prim.If %13 -> (!torch.bool) {\n"
+"      %20 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
+"      %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %21 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %14 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+" }\n" +" %15 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.eq.int %0#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %18 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %19 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%17, %18) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %19 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call 
+"    return %4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
+"    %6 = torch.prim.If %5 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int6 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %6 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %6 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %6 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int5 = torch.constant.int 5\n"
+"    %false = torch.constant.bool false\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      %9 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"      %10 = torch.aten.__not__ %9 : !torch.bool -> !torch.bool\n"
+"      torch.prim.If.yield %10 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    %4 = torch.prim.If %3 -> (!torch.bool) {\n"
+"      %9 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %9 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %7 = torch.prim.ListConstruct %0#1, %6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %8 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %6 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>, %arg4: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %8 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    %int4 = torch.constant.int 4\n"
+"    %false = torch.constant.bool false\n"
+"    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
+"      %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
+"      torch.prim.If.yield %5 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    %3 = torch.prim.If %2 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int4 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %int6 : !torch.int\n"
+"    }\n"
+"    return %3 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "}\n"
 "";
   // clang-format on
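Note (illustration, not part of the patch): each generated `__torch_mlir_dtype_fn` above takes a `(rank, dtype)` tuple per tensor operand and returns the integer code of the result dtype; the IR is compiled from the Python functions added in the last file of this patch. The rules it encodes can be cross-checked against eager PyTorch; the snippet below is a sketch assuming a recent PyTorch build.

    import torch

    # Meta tensors carry shape/dtype but no data, so this checks only the
    # dtype rule that the generated aten.addmm function above encodes.
    self_t = torch.empty(3, 3, dtype=torch.float32, device="meta")
    mat1 = torch.empty(3, 4, dtype=torch.float32, device="meta")
    mat2 = torch.empty(4, 3, dtype=torch.float32, device="meta")
    assert torch.addmm(self_t, mat1, mat2).dtype == torch.float32

    # aten.bincount: int64 result without weights, float64 with weights,
    # mirroring the %int4/%int7 branch in the generated IR.
    idx = torch.tensor([0, 1, 1, 2])
    assert torch.bincount(idx).dtype == torch.int64
    assert torch.bincount(idx, weights=torch.ones(4)).dtype == torch.float64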
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 550f0de5aa08..f5af3e67aa44 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -662,16 +662,6 @@ void TypeAnalysis::visitOperation(Operation *op,
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
-  // Dtype is always si64.
-  if (isa<AtenBincountOp>(op)) {
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype =
-        IntegerType::get(op->getContext(), 64, IntegerType::Signed);
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
   // Dtype is always float32, except for bfloat16, float64 and nullptr after
   // promotion and assuming possible-zero rank.
   if (isa(op)) {
@@ -689,80 +679,11 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
-  // Promote three dtypes.
-  if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
-        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(),
-                           &operands[2]->getValue()});
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
   if (auto linear = llvm::dyn_cast<AtenLinearOp>(op)) {
     visitAtenLinearOp(linear, operands);
     return;
   }
 
-  // Promote LHS with scalar RHS.
-  if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
-          AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
-          AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
-    auto lhs = operands[0]->getValue();
-    Value scalar = op->getOperand(1);
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
-  // Promote 2nd and 3rd operands.
-  if (isa<AtenWhereSelfOp, AtenBaddbmmOp>(op)) {
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultType(
-        op->getContext(), {&operands[1]->getValue(), &operands[2]->getValue()},
-        getRankIsNonZeroArray(op->getOperands().slice(1, 2)));
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
-  // Promote 2nd and 3rd operands.
-  if (isa<AtenWhereScalarOp>(op)) {
-    Value lhsScalar = op->getOperand(1);
-    Value rhsScalar = op->getOperand(2);
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType(
-        {lhsScalar.getType(), rhsScalar.getType()}));
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
-  // Promote 2nd and 3rd operands.
-  if (isa<AtenWhereScalarOtherOp>(op)) {
-    auto lhs = operands[1]->getValue();
-    Value scalar = op->getOperand(2);
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
-  // Promote 2nd and 3rd operands.
-  if (isa<AtenWhereScalarSelfOp>(op)) {
-    auto rhs = operands[2]->getValue();
-    Value scalar = op->getOperand(1);
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType());
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
   // 2 results take dtype from first operand.
   if (isa(op)) {
     auto self = operands[0]->getValue();
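Note (illustration, not part of the patch): the C++ transfer functions removed above reappear as the Python dtype functions below. The key convention is that a scalar operand participates in promotion with rank `None`, so it can change the result's category (int to float) but not its width. A rough emulation of that rule, using `torch.result_type` rather than the torch-mlir helpers:

    import torch

    # A 0-dim tensor stands in for the tensor operand; the Python number is
    # the scalar RHS, as in the removed "promote LHS with scalar RHS" block.
    def promote_with_scalar(tensor_dtype: torch.dtype, scalar) -> torch.dtype:
        return torch.result_type(torch.empty((), dtype=tensor_dtype), scalar)

    assert promote_with_scalar(torch.int32, 2) == torch.int32
    assert promote_with_scalar(torch.int32, 2.0) == torch.float32
    assert promote_with_scalar(torch.float16, 2.0) == torch.float16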
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 3acfae8df5d5..414375a84bfb 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -1794,6 +1794,280 @@ def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dt
     dtypes = [input_dtype, weight_dtype]
     return promote_dtypes(ranks, dtypes)
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=2,
+    error_types={torch.bool, torch.bfloat16, torch.float16, torch.float32, torch.float64,
+                 torch.complex64, torch.complex128}) +
+    _check_tensors_with_the_same_dtype(
+    num_of_tensors=1,
+    error_types={torch.bool, torch.bfloat16, torch.float16, torch.float32, torch.float64,
+                 torch.complex64, torch.complex128}))
+def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype: Optional[Tuple[int, int]] = None, minlength: int = 0) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert is_integer_dtype(self_dtype) and self_dtype != torch.bool
+    if weights_rank_dtype is None:
+        return torch.int64
+    return torch.float64
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool, torch.float16}) +
+    # Different width
+    [ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float64),
+                     TensorOfShape(4, 3, dtype=torch.float32)),
+     # Different type
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.int32)),
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.int32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.float32))])
+def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    mat1_rank, mat1_dtype = mat1_rank_dtype
+    mat2_rank, mat2_dtype = mat2_rank_dtype
+
+    assert self_dtype not in [torch.bool, torch.float16]
+    assert mat1_dtype not in [torch.bool, torch.float16]
+    assert mat2_dtype not in [torch.bool, torch.float16]
+    assert mat1_dtype == mat2_dtype
+    assert self_dtype == mat2_dtype
+    ranks: List[Optional[int]] = [self_rank, mat1_rank, mat2_rank]
+    dtypes = [self_dtype, mat1_dtype, mat2_dtype]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16}) +
+    # Different width
+    [ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float64),
+                     TensorOfShape(4, 3, dtype=torch.float32)),
+     # Different type
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.int32)),
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.int32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.float32))])
+def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    end_rank, end_dtype = end_rank_dtype
+    weight_rank, weight_dtype = weight_rank_dtype
+
+    assert self_dtype not in [torch.bool, torch.float16] and not is_integer_dtype(self_dtype)
+    assert end_dtype not in [torch.bool, torch.float16] and not is_integer_dtype(end_dtype)
+    assert weight_dtype not in [torch.bool, torch.float16] and not is_integer_dtype(weight_dtype)
+    assert end_dtype == weight_dtype
+    ranks: List[Optional[int]] = [self_rank, end_rank, weight_rank]
+    dtypes = [self_dtype, end_dtype, weight_dtype]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool, torch.float16}) +
+    # Different width
+    [ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float64),
+                     TensorOfShape(4, 3, dtype=torch.float32)),
+     # Different type
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.int32)),
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.int32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.float32))])
+def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    tensor1_rank, tensor1_dtype = tensor1_rank_dtype
+    tensor2_rank, tensor2_dtype = tensor2_rank_dtype
+
+    assert self_dtype not in [torch.bool, torch.float16]
+    assert tensor1_dtype not in [torch.bool, torch.float16]
+    assert tensor2_dtype not in [torch.bool, torch.float16]
+    assert tensor1_dtype == tensor2_dtype
+    assert self_dtype == tensor2_dtype
+    ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank]
+    dtypes = [self_dtype, tensor1_dtype, tensor2_dtype]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16}) +
+    # Different width
+    [ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float64),
+                     TensorOfShape(4, 3, dtype=torch.float32)),
+     # Different type
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.float32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.int32)),
+     ErrorInvocation(TensorOfShape(3, 3, dtype=torch.int32),
+                     TensorOfShape(3, 4, dtype=torch.float32),
+                     TensorOfShape(4, 3, dtype=torch.float32))])
+def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    tensor1_rank, tensor1_dtype = tensor1_rank_dtype
+    tensor2_rank, tensor2_dtype = tensor2_rank_dtype
+
+    assert self_dtype not in [torch.bool, torch.float16] and not is_integer_dtype(self_dtype)
+    assert tensor1_dtype not in [torch.bool, torch.float16] and not is_integer_dtype(tensor1_dtype)
+    assert tensor2_dtype not in [torch.bool, torch.float16] and not is_integer_dtype(tensor2_dtype)
+    assert tensor1_dtype == tensor2_dtype
+    assert self_dtype == tensor2_dtype
+    ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank]
+    dtypes = [self_dtype, tensor1_dtype, tensor2_dtype]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
+                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
+def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, other=1) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, other=1.0))
+def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.bool
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
+                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
+def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
+                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
+def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    if is_integer_dtype(promoted_dtype):
+        return torch.float32
+    return promoted_dtype
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
+def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_complex_dtype(self_dtype)
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
+def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_complex_dtype(self_dtype)
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
+                      _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
+def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(exponent)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(
+        num_of_tensors=1,
+        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.complex64, torch.complex128},
+        negative_slope=1) +
+    _check_tensors_with_the_same_dtype(
+        num_of_tensors=1,
+        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.complex64, torch.complex128},
+        negative_slope=1.0))
+def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(negative_slope)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
+def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_complex_dtype(self_dtype)
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], error_types={torch.bool, torch.float16}) +
+    [ErrorInvocation(TensorOfShape(
+        1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.int16), TensorOfShape(1, 1, 1, dtype=torch.int32)),
+     ErrorInvocation(
+        TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.int64), TensorOfShape(1, 1, 1, dtype=torch.float16)),
+     ErrorInvocation(
+        TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, dtype=torch.int64)),
+     ErrorInvocation(
+        TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, dtype=torch.float16))])
+def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int:
+    batch1_rank, batch1_dtype = batch1_rank_dtype
+    batch2_rank, batch2_dtype = batch2_rank_dtype
+    assert batch1_dtype not in [torch.bool, torch.float16]
+    assert batch2_dtype not in [torch.bool, torch.float16]
+    assert batch1_dtype == batch2_dtype
+    ranks: List[Optional[int]] = [batch1_rank, batch2_rank]
+    dtypes = [batch1_dtype, batch2_dtype]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function([
+    Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), NonZeroDTensorWithDtype(torch.int32)),
+    Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), NonZeroDTensorWithDtype(torch.float16)),
+    Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.int64)),
+    Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.bfloat16), NonZeroDTensorWithDtype(torch.float16))])
+def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    other_rank, other_dtype = other_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, other_rank]
+    dtypes = [self_dtype, other_dtype]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)])
+def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int:
+    if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)):
+        return torch.int64
+    return torch.float32
+
+@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), 0),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)])
+def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
+@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.int16)),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)),
+                       Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))])
+def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int:
+    other_rank, other_dtype = other_rank_dtype
+    ranks: List[Optional[int]] = [None, other_rank]
+    dtypes = [get_dtype_of_scalar(self), other_dtype]
+    return promote_dtypes(ranks, dtypes)
+
 # ==============================================================================
 # Main
 # ==============================================================================
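Note (illustration, not part of the patch): two of the rules added above have easily observed eager-mode counterparts; the spot checks below are a sketch assuming a recent PyTorch build.

    import torch

    # aten.div.Scalar: an integer promoted result becomes the default float
    # dtype (float32); floating inputs keep their dtype.
    x = torch.arange(4, dtype=torch.int64)
    assert (x / 2).dtype == torch.float32
    assert (torch.ones(4, dtype=torch.float64) / 2).dtype == torch.float64

    # aten.where.Scalar: two int scalars give int64, otherwise float32.
    cond = torch.tensor([True, False])
    assert torch.where(cond, 1, 0).dtype == torch.int64
    assert torch.where(cond, 1.0, 0).dtype == torch.float32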