diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index fd85f6c8d220..ac3ae738edf9 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -288,7 +288,7 @@ function test_in_tree() { python -m e2e_testing.main --config=lazy_tensor_core -v echo ":::: Run TorchDynamo e2e integration tests" - python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic Matmul_dot + python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic } function setup_venv() { diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 5be3b6c116ad..d5c7e9b60121 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -96,8 +96,6 @@ # Dtype function transition failures "MobilenetV3Module_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", } STABLEHLO_PASS_SET = { @@ -706,7 +704,8 @@ "FullLikeModuleInt2DStatic_basic", "FullModuleInt3D_basic", "FullModuleFloat2D_basic", - "RepeatModule_basic" + "RepeatModule_basic", + "ResNet18StaticModule_basic", } LTC_XFAIL_SET = { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 2bef594c21b7..b7e6ce3ea981 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7690,20 +7690,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %3 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" " %int7 = torch.constant.int 7\n" " %int6 = torch.constant.int 6\n" " %int15 = torch.constant.int 15\n" " %int5 = torch.constant.int 5\n" " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" " %int10 = torch.constant.int 10\n" " %int9 = torch.constant.int 9\n" " %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" 
-" return %1 : !torch.bool\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -7777,6 +7785,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" " %int4 = torch.constant.int 4\n" " %int3 = torch.constant.int 3\n" " %int2 = torch.constant.int 2\n" @@ -7784,8 +7797,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int0 = torch.constant.int 0\n" " %int11 = torch.constant.int 11\n" " %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" @@ -7837,6 +7849,532 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %3:2 = torch.prim.If %2 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" %5 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %6:2 = torch.prim.If %5 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %6#0, %6#1 : !torch.bool, !torch.int\n" +" }\n" +" %4 = torch.prim.If %3#0 -> (!torch.int) {\n" +" torch.prim.If.yield %3#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> 
!torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli\"(%arg0: !torch.tuple, %arg1: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_to\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ceil\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" 
+" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clone\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.copy\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cpu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumsum\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> 
!torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flatten.using_ints\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gather\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" +" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu\"(%arg0: !torch.tuple, %arg1: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardsigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardswish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" 
+" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %int0, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put.hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor_hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.list>>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> 
!torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_select\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.narrow\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.numpy_T\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pad\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return 
%0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.permute\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.prelu\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.repeat\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reshape\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.resize_\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: 
!torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.roll\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.round\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze.dim\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.tanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.t\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.triu\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unsqueeze\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.zero\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero_\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" +" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" +" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" @@ -7853,10 +8391,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" -" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -7964,11 +8498,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" -" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index bbef876fd1a7..301ab7bab59f 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -613,39 +613,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - // Take dtype from first operand. - if (isa(op)) { - return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); - } - // Dtype is always float32, except for bfloat16, float64 and nullptr after // promotion and assuming possible-zero rank. 
if (isa(op)) { @@ -1365,8 +1332,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { // the right thing forthose ops. // static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) { - return op->hasTrait() || - isa(op); + return op->hasTrait(); } // Some operations have extra verification logic regarding the relationship diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 04bbe0220e71..25403890e85b 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -167,12 +167,14 @@ FailureOr Torch::adjustFunctionArg( return b.create(loc, desiredType, operand).getResult(); } - // !torch.union is the type used for `Scalar` inputs. At - // compile time, such inputs will usually be resolved to an `int` or a `float` - // so we need to derefine to match the library function signature. + // !torch.union or !torch.union is the type used + // for (optional) `Scalar` inputs. At compile time, such inputs will usually + // be resolved to an `int` or a `float` so we need to derefine to match the + // library function signature. if (auto unionType = desiredType.dyn_cast()) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { - return containedType.isa(); + return containedType + .isa(); })) return b.create(loc, desiredType, operand).getResult(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 4ee73266bdcd..162146b3e152 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -12,7 +12,7 @@ import torch.jit._shape_functions as upstream_shape_functions from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function -from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype, get_priority_of_dtype +from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype, get_priority_of_dtype, all_integer_dtypes, all_float_dtypes, all_complex_dtypes # ============================================================================== # Shape Functions @@ -1088,6 +1088,7 @@ def _check_tensors_with_the_same_dtype( return invocations def _check_two_tensor_op( + tensor_shapes: Optional[list[tuple[int]]] = None, input_error_types: Optional[set[int]] = None, output_error_types: Optional[set[int]] = None, **kwargs): """Generate invocations for basic two-tensor dtype functions. @@ -1101,6 +1102,10 @@ def _check_two_tensor_op( must be specified in `input_error_types` and `output_error_types` to ensure the invocations are error invocations. 
""" + if tensor_shapes is None: + tensor_shapes = [(1,), (1,)] + shape_1, shape_2 = tensor_shapes + if input_error_types is not None and output_error_types is not None: assert len(input_error_types.intersection(output_error_types)) == 0, \ "An invalid input type implies an invalid output type, " \ @@ -1138,18 +1143,16 @@ def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs): invocations = [] for type_ in _SORTED_TORCH_TYPES[constant_type_index:]: - tensor_1 = NonZeroDTensorWithDtype(type_) - tensor_2 = NonZeroDTensorWithDtype(constant_type) if input_error_types is not None and type_ in input_error_types: - invocations += [ErrorInvocation(tensor_1, tensor_2, **kwargs), - ErrorInvocation(tensor_2, tensor_1, **kwargs)] + invocation_type = ErrorInvocation else: - invocations += [Invocation(tensor_1, tensor_2, **kwargs), - Invocation(tensor_2, tensor_1, **kwargs)] + invocation_type = Invocation + invocations += [invocation_type(TensorOfShape(*shape_1, dtype=type_), TensorOfShape(*shape_2, dtype=constant_type), **kwargs), + invocation_type(TensorOfShape(*shape_1, dtype=constant_type), TensorOfShape(*shape_2, dtype=type_), **kwargs)] return invocations same_dtype_invocations = _check_tensors_with_the_same_dtype( - num_of_tensors=2, error_types=all_error_types, **kwargs) + tensor_shapes=tensor_shapes, error_types=all_error_types, **kwargs) varying_dtype_invocations = \ check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs) @@ -1252,6 +1255,524 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) +def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def 
aten〇bernoulli〡dtype(self_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) +def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0)) +def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) +def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇copy〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# TODO: Cannot call .cpu() on meta tensor due to: Cannot copy out of meta tensor; no data! 
+def aten〇cpu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) +def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], implicit: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0)) +def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1,), ()])) +def aten〇fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_dim: int = 0, end_dim: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇floor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇gather〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], sparse_grad: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇gelu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + 
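# Illustrative cross-check, a minimal sketch assuming promote_dtypes follows
# PyTorch's rank-aware type promotion (in the generated library above it
# lowers to torch.promote_dtypes): the results that two-tensor dtype functions
# such as aten〇gelu_backward〡dtype compute can be reproduced in eager mode
# with torch.result_type on representative operands.
import torch

def _promotion_crosscheck_sketch():
    # Two dimensioned tensors promote by ordinary dtype promotion.
    half = torch.empty(1, dtype=torch.float16)
    single = torch.empty(1, dtype=torch.float32)
    assert torch.result_type(half, single) == torch.float32
    # A zero-rank operand contributes only its dtype category, so it does not
    # widen a dimensioned half-precision operand; this is why the dtype
    # functions pass ranks alongside dtypes instead of dtypes alone.
    zero_rank_single = torch.empty((), dtype=torch.float32)
    assert torch.result_type(half, zero_rank_single) == torch.float16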
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇gelu〡dtype(self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardsigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5)) +def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + if is_integer_dtype(grad_output_dtype): + return torch.float32 + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool})) +def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype not in [torch.uint8, torch.bool] + return self_dtype + +_index_put_invocations = [ + # same dtype + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES +] + [ + # different dtypes + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] + [ + # index dtype + Invocation(TensorOfShape(3, dtype=torch.float32), [TensorOfShape(3, dtype=dtype)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] +@check_dtype_function(_index_put_invocations) +def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor_hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + 
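# A minimal sketch, assuming these mirror the library_generator helpers this
# patch newly imports (all_integer_dtypes, all_float_dtypes,
# all_complex_dtypes) and lowers into the all_*_dtypes functions in the
# generated AbstractInterpLibrary.cpp hunk above; the integer constants
# 11, 0, 1, 2, 3, 4 / 5, 15, 6, 7 / 9, 10 in that IR are the ScalarType codes
# of the dtypes listed here. aten〇layer_norm〡dtype below uses them for its
# error_types. List[int] follows this file's TorchScript convention of typing
# dtype values as int.
from typing import List
import torch

def _all_integer_dtypes_sketch() -> List[int]:
    return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]

def _all_float_dtypes_sketch() -> List[int]:
    return [torch.float16, torch.bfloat16, torch.float32, torch.float64]

def _all_complex_dtypes_sketch() -> List[int]:
    return [torch.complex64, torch.complex128]

def _is_integer_dtype_sketch(dtype: int) -> bool:
    # Membership test over the integer dtype list, matching the refactored
    # is_integer_dtype seen in the generated library above.
    return dtype in _all_integer_dtypes_sketch()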
+@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1])) +def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> int: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype + +@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False)) +def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_two_tensor_op(dim=0, input_dtype=torch.float32) + + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) +def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), TensorOfShape(dtype=torch.float32))) +def aten〇masked_fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# TODO: This op cannot be run on the Meta backend. We should have the test suite default on CPU when the Meta backend does not work. 
+def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, start=0, length=1)) +def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int, length: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇numpy_T〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇pad〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇permute〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇pow〇Tensor_Tensor〡dtype(self_rank_dtype: Tuple[int, int], exponent_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + exponent_rank, exponent_dtype = exponent_rank_dtype + ranks: List[Optional[int]] = [self_rank, exponent_rank] + dtypes = [self_dtype, exponent_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if promoted_dtype == torch.bool: + return torch.int64 + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2) + + [ErrorInvocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64)), + ErrorInvocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32))]) +def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert self_dtype == weight_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇relu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=[1])) +def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1])) +def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shape=[1])) +def aten〇reshape〡dtype(self_rank_dtype: Tuple[int, int], shape: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇resize_〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shifts=[0], dims=[0])) +def aten〇roll〡dtype(self_rank_dtype: Tuple[int, int], shifts: List[int], dims: List[int] = ()) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES]) +def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, index=0)) +def aten〇select〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(tensor_shapes=[(1, 1), (1,)], dim=0, index=0)) +def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(dim=0)) +def aten〇slice_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) + + [Invocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, 
dtype=torch.float64), dim=0, input_dtype=torch.float32), + Invocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32), dim=0, input_dtype=torch.float32)]) +def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇squeeze〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇squeeze〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + output_rank, output_dtype = output_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, output_rank] + dtypes = [grad_output_dtype, output_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0)) +def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, device=torch.device("meta"))) +def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1)) +def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)])) +def aten〇triu〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇_unsafe_view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇unsqueeze〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return 
self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 4, 8)], output_size=[4, 8], input_size=[1, 1, 2, 3])) +def aten〇upsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13])) +def aten〇upsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function([Invocation(-1), Invocation(-1.0)]) +def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + @check_dtype_function(_check_tensors_with_the_same_dtype( None, [(3,), (3, 4)], None, TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) + diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 18167a26c50a..9608ce6c639a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -14,15 +14,23 @@ from .registry import Registry +def all_integer_dtypes() -> List[int]: + return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + def is_integer_dtype(dtype: int) -> bool: - return dtype in [torch.bool, torch.uint8, torch.int8, - torch.int16, torch.int32, torch.int64] + return dtype in all_integer_dtypes() + +def all_complex_dtypes() -> List[int]: + return [torch.complex64, torch.complex128] def is_complex_dtype(dtype: int) -> bool: - return dtype in [torch.complex64, torch.complex128] + return dtype in all_complex_dtypes() + +def all_float_dtypes() -> List[int]: + return [torch.float16, torch.bfloat16, torch.float32, torch.float64] def is_float_dtype(dtype: int) -> bool: - return dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64] + return dtype in all_float_dtypes() def get_priority_of_dtype(dtype: int) -> int: sorted_types = [ diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index fd3354909031..e06a3f537d30 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -86,7 +86,7 @@ def _recursively_transform_tensor_args( o: Any, tensor_transformer: 
Callable[[TensorOfShape], Any]) -> Any: """Replace `TensorOfShape` with the result of `tensor_transformer`""" - if o is None or isinstance(o, (float, int)): + if o is None or isinstance(o, (float, int, str)): return o if isinstance(o, TensorOfShape): return tensor_transformer(o) diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index b0ea4dd8b770..2d584489842b 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -8,9 +8,6 @@ # to the backend contract. COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", - "MobilenetV3Module_basic", "ReduceMaxAlongDimUnsignedInt_basic", } diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index ae02a0339d8e..8af362cd0d64 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -8,109 +8,6 @@ // ----- -// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static( -// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>, -// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32> -// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32> -// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32> -func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { - %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor - %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor - %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor - torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor - %static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor - %result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32> - return %result : !torch.vtensor<[2],f32> -} - -// ----- -// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { -// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32> -// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32> -// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32> -// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32> -func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { - %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor - %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor - %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : 
!torch.tensor - torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor - %dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor - %result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32> - return %result : !torch.vtensor<[?],f32> -} - -// ----- -// CHECK-LABEL: func.func @bf16_result_type( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { -// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16> -// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16> -func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { - %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16> - return %1 : !torch.vtensor<[2],bf16> -} - -// ----- -// CHECK-LABEL: func.func @propagate_scalar_type( -// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number { -// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number -// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int -// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number -// CHECK: return %[[RET]] : !torch.number -func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number { - %num = torch.derefine %arg0 : !torch.int to !torch.number - %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number - return %1 : !torch.number -} - -// ----- -// CHECK-LABEL: func.func @prim.dtype( -// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor { - -// CHECK: %[[zero:.*]] = torch.constant.int 0 -// CHECK: %[[false:.*]] = torch.constant.bool false - -// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16> -// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int -// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device -// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16> - -// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int -// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device -// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16> - -// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor -// CHECK: return %[[cast]] : !torch.vtensor -// CHECK: } - -func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> { - %zero = torch.constant.int 0 - %false = torch.constant.bool false - - // Op that requires type refinement - %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk> - - // Op whose processing requires type refinement on its source argument. 
- %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int - %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device - - // Another op that requires type refinement - %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk> - - // Repeat the above three ops a second time to ensure that the type refinement - // code works regardless of the number of alternating refinement+prim.dtype - // sequences. - %dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int - %device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device - %result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk> - - return %result2 : !torch.vtensor<*,unk> -} - -// ----- - // Check that we don't crash on this input. // CHECK-LABEL: func.func @forward