From 7b1c93f0b07253fba8856b88ed90e4a4a428aa3f Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Fri, 24 Mar 2023 19:35:24 +0000 Subject: [PATCH 1/5] Add dtype functions for "Arange ops" --- .../Transforms/AbstractInterpLibrary.cpp | 110 ++++++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 62 ---------- .../build_tools/abstract_interp_lib_gen.py | 45 +++++++ 3 files changed, 155 insertions(+), 62 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b7e6ce3ea981..05f5f9ac4761 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9321,6 +9321,116 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : 
!torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield 
%int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 301ab7bab59f..4a7df8d99a7c 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -429,12 +429,6 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis< void visitAtenLinearOp(AtenLinearOp op, ArrayRef operands); - void visitAtenArangeStartStepOp(AtenArangeStartStepOp op); - void visitAtenArangeStartOp(AtenArangeStartOp op); - void visitAtenArangeOp(AtenArangeOp op); - void visitAtenArangeLikeOpHelper(Operation *op, std::optional start, - Value end, std::optional step, - Value dtype); void visitReductionAlongAllDimsOp(Operation *op, Type dtype, ArrayRef operands); void visitReductionAlongDimIntListOp(Operation *op, Value dim, Value keepdim, @@ -649,19 +643,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto arange = dyn_cast(op)) { - visitAtenArangeOp(arange); - return; - } - if (auto arangeStart = dyn_cast(op)) { - visitAtenArangeStartOp(arangeStart); - return; - } - if (auto arangeStartStep = dyn_cast(op)) { - visitAtenArangeStartStepOp(arangeStartStep); - return; - } - if (auto sum = dyn_cast(op)) { Type defaultDtype = operands[0]->getValue().dtype; // If the input dtype is bool, the result type should be i64. @@ -1021,49 +1002,6 @@ void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) { return; } -// Arange like ops returns a 1-D tensor of size ceil(end - start). -void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op, - std::optional start, - Value end, - std::optional step, - Value dtype) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - int64_t dtypeInt; - if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) { - knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt); - } else if (dtype.getType().isa()) { - // From torch/_torch_docs.py: - // If `dtype` is not given, infer the data type from the other input - // arguments. If any of `start`, `end`, or `step` are floating-point, the - // `dtype` is inferred to be the default dtype, see - // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to - // be `torch.int64` - if ((start.has_value() && (*start).getType().isa()) || - end.getType().isa() || - (step.has_value() && (*step).getType().isa())) { - // TODO: Should get the dtype from torch.get_default_dtype(). - // For now, use float32 which is the initial default dtype. 
- knowledge.dtype = Float32Type::get(op->getContext()); - } else - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - } - incorporateKnowledge(op->getResult(0), knowledge); -} - -void TypeAnalysis::visitAtenArangeStartStepOp(AtenArangeStartStepOp op) { - visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), op.getStep(), op.getDtype()); -} - -void TypeAnalysis::visitAtenArangeStartOp(AtenArangeStartOp op) { - visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), {}, op.getDtype()); -} - -void TypeAnalysis::visitAtenArangeOp(AtenArangeOp op) { - visitAtenArangeLikeOpHelper(op, {}, op.getEnd(), {}, op.getDtype()); -} - void TypeAnalysis::visitReductionAlongAllDimsOp( Operation *op, Type dtype, ArrayRef operands) { auto knowledge = diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 162146b3e152..5675d890cb78 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2566,6 +2566,51 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r result_dtype = torch.float32 return input_dtype, input_dtype, result_dtype +@check_dtype_function([Invocation(end=0, dtype=None), # No floats + Invocation(end=0.0, dtype=None), # One float + ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified + Invocation(end=0, dtype=torch.float16), # Dtype specified + Invocation(end=0, dtype=torch.int16)]) # Dtype specified +def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, dtype=None), # No floats + Invocation(start=0.0, end=10, dtype=None), # One float + Invocation(start=0, end=10.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, step=1, dtype=None), # No floats + Invocation(start=0.0, end=10, step=1, dtype=None), # One float + Invocation(start=0, end=10.0, step=1, dtype=None), # One float + Invocation(start=0, end=10, step=1.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = 
None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)) or \ + is_float_dtype(get_dtype_of_scalar(step)): + return torch.float32 + return torch.int64 + # ============================================================================== # Main # ============================================================================== From 9c7c7b2ec4dc90a146cc257b0140a7daeda29361 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 18 Apr 2023 17:30:48 +0000 Subject: [PATCH 2/5] Add dtype functions for reduction ops --- e2e_testing/xfail_sets.py | 3 - .../Transforms/AbstractInterpLibrary.cpp | 178 ++++++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 94 --------- .../ReifyAbstractInterpCalculationsUtils.cpp | 2 +- .../build_tools/abstract_interp_lib_gen.py | 123 ++++++++++++ test/Dialect/Torch/refine-types-ops.mlir | 76 -------- 6 files changed, 302 insertions(+), 174 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index d5c7e9b60121..bf21e8bae35f 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -93,9 +93,6 @@ "ElementwiseAddScalar_NumToTensorFloat_Module_basic", # ERROR: assert isinstance(e, FakeTensor) "RsubInt0d_NumToTensor_Module_basic", - - # Dtype function transition failures - "MobilenetV3Module_basic", } STABLEHLO_PASS_SET = { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 05f5f9ac4761..8812e623ff42 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9431,6 +9431,184 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum.dim_IntList\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" %1 
= call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.argmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.amax\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = call @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0, %0, %false, %arg1) : (!torch.tuple, !torch.optional>, !torch.bool, !torch.optional) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = 
call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : 
(!torch.int) -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple\n" +" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 4a7df8d99a7c..357cf8a87ecf 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -643,91 +643,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto sum = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - // If the input dtype is bool, the result type should be i64. - if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sum.getContext(), sum.getDtype(), defaultDtype); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = dtype; - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - if (auto sumDimIntList = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - if (!defaultDtype) { - incorporateKnowledge( - sumDimIntList.getResult(), - ValueKnowledge::getTensorPessimisticValueState(op->getContext())); - return; - } - // If the input dtype is bool, the result type should be i64. 
- if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sumDimIntList.getContext(), - sumDimIntList.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.getDim(), - sumDimIntList.getKeepdim(), dtype, operands); - return; - } - if (auto meanDim = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(meanDim.getContext(), meanDim.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(meanDim, meanDim.getDim(), meanDim.getKeepdim(), - dtype, operands); - return; - } - if (auto argmax = dyn_cast(op)) { - Value dim = argmax.getDim(); - Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - if (dim.getType().isa()) { - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } - if (dim.getType().isa()) { - visitReductionAlongDimIntOp(argmax, argmax.getDim(), argmax.getKeepdim(), dtype, - operands); - return; - } - } - if (auto anyDim = dyn_cast(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongDimIntOp(anyDim, anyDim.getDim(), anyDim.getKeepdim(), dtype, - operands); - return; - } - if (auto maxDim = dyn_cast(op)) { - Type firstResDtype = operands[0]->getValue().dtype; - Type secondResDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - firstResDtype, operands); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - secondResDtype, operands, /*resNum=*/1); - return; - } - if (auto mean = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(mean.getContext(), mean.getDtype(), defaultDtype); - visitReductionAlongAllDimsOp(mean, dtype, operands); - return; - } else if (isa(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } else if (isa(op)) { - auto input = operands[0]->getValue(); - visitReductionAlongAllDimsOp(op, input.dtype, operands); - return; - } - if (auto tensorFloat = dyn_cast(op)) { visitScalarToTensorConversionOp(tensorFloat); return; @@ -883,15 +798,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto vectorNorm = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.getDtype(), - defaultDtype); - visitReductionAlongDimIntListOp(vectorNorm, vectorNorm.getDim(), - vectorNorm.getKeepdim(), dtype, operands); - return; - } - if (auto randIntLow = dyn_cast(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 8032a90982b4..7e3697302c5b 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -182,7 +182,7 @@ FailureOr Torch::adjustFunctionArg( // If the operand is NoneType, then we just need to derefine it to the // optional type in the function signature. 
if (operandType.isa()) { - assert(desiredType.isa() && + assert(!desiredType.isa() && "Don't expect library functions to have NoneType parameters"); return b.create(loc, desiredType, operand).getResult(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 5675d890cb78..b75489378141 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2611,6 +2611,129 @@ def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, return torch.float32 return torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64)) +def aten〇sum〇dim_IntList〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + return aten〇sum〡dtype(self_rank_dtype, dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dim=None, dtype=torch.int32)]) +def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + result = aten〇sum〡dtype(self_rank_dtype, dtype) + assert not is_integer_dtype(result) + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.uint8: + return self_dtype + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = 
False) -> int: + return aten〇max〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: + return aten〇max〡dtype(self_rank_dtype), torch.int64 + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇mean〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + return aten〇mean〇dim〡dtype(self_rank_dtype, dim=None, dtype=dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex64: + return torch.float32 + if self_dtype == torch.complex128: + return torch.float64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) +def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> int: + return aten〇std〡dtype(inp_rank_dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) 
+def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if dtype is not None: + assert not is_integer_dtype(dtype) + if is_complex_dtype(self_dtype): + assert is_complex_dtype(dtype) + return aten〇std〡dtype((self_rank, dtype)) + assert not is_complex_dtype(dtype) + return dtype + return aten〇std〡dtype(self_rank_dtype) + # ============================================================================== # Main # ============================================================================== diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir index 6fc29daaba08..06ec0d107f44 100644 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ b/test/Dialect/Torch/refine-types-ops.mlir @@ -3,58 +3,6 @@ // This file is for tests for individual ops that require a new transfer // function (i.e. new code called from visitOperation). -// ----- -// CHECK-LABEL: func.func @aten.arange.start$int64_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.int, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$float32_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.float, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$specified_dtype( -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange -// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : 
!torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor { - %int6 = torch.constant.int 6 - %none = torch.constant.none - %ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - // ----- // CHECK-LABEL: func.func @torch.aten.linear( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, @@ -68,30 +16,6 @@ func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vte return %1 : !torch.vtensor } -// ----- -// CHECK-LABEL: func.func @aten.sum.dim_IntList( -// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 -// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]] -// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]] -// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list, !torch.bool, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor { - %false = torch.constant.bool false - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int-1 = torch.constant.int -1 - %dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list - %ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - // ----- // CHECK-LABEL: func.func @torch.aten.zeros( // CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor { From bf91c220731e61fcd7f9097c24c7dc9f94532335 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 24 Apr 2023 21:34:23 +0000 Subject: [PATCH 3/5] Add dtype functions for scalar to tensor ops --- .../Transforms/AbstractInterpLibrary.cpp | 36 +++++++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 11 ------ .../build_tools/abstract_interp_lib_gen.py | 27 ++++++++++++++ test/Dialect/Torch/refine-types-ops.mlir | 32 ----------------- 4 files changed, 63 insertions(+), 43 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 8812e623ff42..40dc27f04a63 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9609,6 +9609,42 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" 
}\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.int\"(%arg0: !torch.int, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.bool\"(%arg0: !torch.bool, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 357cf8a87ecf..92c2bbe144a1 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -643,17 +643,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto tensorFloat = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorFloat); - return; - } else if (auto tensorInt = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorInt); - return; - } else if (auto tensorBool = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorBool); - return; - } - if (auto tensor = dyn_cast(op)) { visitAtenTensorOp(tensor); return; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index b75489378141..3fbc09ec7efe 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2734,6 +2734,33 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni return dtype return aten〇std〡dtype(self_rank_dtype) +@check_dtype_function([Invocation(0.0), + Invocation(0.0, dtype=torch.int32), + Invocation(0.0, dtype=torch.float16), + Invocation(0.0, dtype=torch.complex64)]) +def aten〇tensor〇float〡dtype(t: float, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.float32 + return dtype + +@check_dtype_function([Invocation(0), + Invocation(0, dtype=torch.int32), + Invocation(0, dtype=torch.float16), + Invocation(0, dtype=torch.complex64)]) +def aten〇tensor〇int〡dtype(t: int, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.int64 + return dtype + +@check_dtype_function([Invocation(True), + Invocation(True, dtype=torch.int32), + Invocation(True, dtype=torch.float16), + Invocation(True, dtype=torch.complex64)]) +def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is 
None: + return torch.bool + return dtype + # ============================================================================== # Main # ============================================================================== diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir index 06ec0d107f44..633b7757966e 100644 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ b/test/Dialect/Torch/refine-types-ops.mlir @@ -115,38 +115,6 @@ func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: return %ret: !torch.tensor } -// ----- -// CHECK-LABEL: func.func @torch.aten.tensor.float( -// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor { - %none = torch.constant.none - %false = torch.constant.bool false - %ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype( -// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[CST11:.*]] = torch.constant.int 11 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor { - %none = torch.constant.none - %int11 = torch.constant.int 11 - %false = torch.constant.bool false - %ret = torch.aten.tensor.float %t, %int11, %none, %false : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor - return %ret : !torch.tensor -} - // ----- // CHECK-LABEL: func.func @torch.aten.softmax.int( // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, From f5e9d5fd4fe3dc03e24442fdbe7f3ccbcb5eb8d2 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 24 Apr 2023 22:45:35 +0000 Subject: [PATCH 4/5] Add dtype functions for constant tensor allocation ops --- .../Transforms/AbstractInterpLibrary.cpp | 187 ++++++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 50 ----- .../build_tools/abstract_interp_lib_gen.py | 126 ++++++++++++ test/Dialect/Torch/refine-types-ops.mlir | 17 -- test/Dialect/Torch/refine-types.mlir | 21 -- 5 files changed, 313 insertions(+), 88 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 40dc27f04a63..b0cce34ba0c9 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9645,6 +9645,193 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: 
!torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty.memory_format\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" 
}\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_ones\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty\"(%arg0: !torch.tuple, %arg1: !torch.list, 
%arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty_strided\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._to_copy\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 92c2bbe144a1..745f0cca388f 100644 
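As a sanity check, the defaulting behavior the generated dtype functions above encode can be reproduced in eager PyTorch. The following is only a minimal sketch, assuming a reasonably recent PyTorch release; it is illustrative and not part of the patch:

```python
# Eager-mode checks mirroring the dtype rules of the tensor-allocation ops:
# float32 when no dtype is given, the explicit dtype otherwise, and the
# input tensor's dtype for the *_like / new_* variants.
import torch

assert torch.zeros(2).dtype == torch.float32                     # default dtype
assert torch.zeros(2, dtype=torch.int32).dtype == torch.int32    # explicit dtype wins
assert torch.full((2,), 1.0).dtype == torch.float32              # float fill -> float32
assert torch.full((2,), 1).dtype == torch.int64                  # int fill keeps its integer dtype

src = torch.ones(2, dtype=torch.int32)
assert torch.zeros_like(src).dtype == torch.int32                # *_like follows the input
assert torch.zeros_like(src, dtype=torch.float16).dtype == torch.float16
assert src.new_zeros((2,)).dtype == torch.int32                  # new_* follows the input
```

Note that `aten.full` is the only allocation op here whose default depends on the fill value, which is why its dtype function departs from the simple `self_dtype if dtype is None else dtype` pattern used by the `*_like` and `new_*` ops.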
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -648,56 +648,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto zeros = dyn_cast(op)) { - visitConstantTensorAllocOp(zeros, /*dataType=*/{}); - return; - } else if (auto ones = dyn_cast(op)) { - visitConstantTensorAllocOp(ones, /*dataType=*/{}); - return; - } else if (auto emptyMemoryFormat = dyn_cast(op)) { - visitConstantTensorAllocOp(emptyMemoryFormat, - /*dataType=*/{}); - return; - } else if (auto full = dyn_cast(op)) { - visitConstantTensorAllocOp( - full, /*dataType=*/full.getFillValue().getType()); - return; - } else if (auto zerosLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(zerosLike, operands); - return; - } else if (auto onesLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(onesLike, operands); - return; - } else if (auto emptyLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(emptyLike, operands); - return; - } else if (auto fullLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(fullLike, operands); - return; - } else if (auto newZeros = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newZeros, operands); - return; - } else if (auto newOnes = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newOnes, operands); - return; - } else if (auto newEmpty = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmpty, operands); - return; - } else if (auto newEmptyStrided = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmptyStrided, - operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto toCopy = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(toCopy, operands); - return; - } - if (auto toDtype = dyn_cast(op)) { visitAtenToDtypeLikeOp(toDtype, operands); return; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 3fbc09ec7efe..7c57bc9a1959 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2761,6 +2761,132 @@ def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: O return torch.bool return dtype +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇zeros〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇ones〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, 
device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1], 0.0), + Invocation([1], 0), + Invocation([1], 0.0, dtype=torch.int32), + Invocation([1], 0.0, dtype=torch.float16), + Invocation([1], 0.0, dtype=torch.complex64)]) +def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + return dtype + fill_value_dtype = get_dtype_of_scalar(fill_value) + if is_float_dtype(fill_value_dtype): + return torch.float32 + return fill_value_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇zeros_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇ones_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64)) +def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + 
+ _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_zeros〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_ones〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_empty〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.complex64)) +def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇rand_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, 
dtype=torch.complex64)) +def aten〇randn_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇_to_copy〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + # ============================================================================== # Main # ============================================================================== diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir index 633b7757966e..771ce5a9582a 100644 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ b/test/Dialect/Torch/refine-types-ops.mlir @@ -16,23 +16,6 @@ func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vte return %1 : !torch.vtensor } -// ----- -// CHECK-LABEL: func.func @torch.aten.zeros( -// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor { - %none = torch.constant.none - %int2 = torch.constant.int 2 - %sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list - %ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor - return %ret : !torch.tensor -} - // ----- // CHECK-LABEL: func.func @torch.aten.type_as( // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>, diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index 8af362cd0d64..df2722037496 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -40,27 +40,6 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) { // ----- -// CHECK-LABEL: func.func @torch.aten.zeros_like( -// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[CPU:.*]] = torch.constant.device "cpu" -// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, 
!torch.bool, !torch.int -> !torch.vtensor<*,f32> -// CHECK: return -func.func @torch.aten.zeros_like(%arg: !torch.vtensor) { - %int6 = torch.constant.int 6 - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %cpu = torch.constant.device "cpu" - %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor - return -} - -// ----- - // The data-flow analysis does not always propagate information to the entire graph. // This results in some lattice elements being uninitialized, which must be properly // handled when using the lattice elements to rewrite the graph. From 384626fa6f9e98df7218fca4f6d26ab625f94df6 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 24 Apr 2023 23:10:20 +0000 Subject: [PATCH 5/5] Add dtype functions for type conversion ops --- .../Transforms/AbstractInterpLibrary.cpp | 29 ++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 29 -------------- .../build_tools/abstract_interp_lib_gen.py | 39 +++++++++++++++++++ test/Dialect/Torch/refine-types-ops.mlir | 29 -------------- 4 files changed, 68 insertions(+), 58 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b0cce34ba0c9..c767e2ba67fd 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9832,6 +9832,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.nvprims.convert_element_type\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype_layout\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.device\"(%arg0: !torch.tuple, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.int {\n" +" return %arg2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.other\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" "}\n" ""; // 
clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 745f0cca388f..8fa4a2feaa07 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -648,35 +648,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto primsConvertElementType = dyn_cast(op)) { - visitAtenToDtypeLikeOp(primsConvertElementType, - operands); - return; - } - - if (auto toDtypeLayout = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtypeLayout, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto toOther = dyn_cast(op)) { - visitTypeConversionOp(toOther, operands); - return; - } else if (auto typeAs = dyn_cast(op)) { - visitTypeConversionOp(typeAs, operands); - return; - } - if (auto cat = dyn_cast(op)) { visitAtenCatLikeOp(cat, operands); return; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 7c57bc9a1959..532a0c594a82 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2887,6 +2887,45 @@ def aten〇_to_copy〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[in self_rank, self_dtype = self_rank_dtype return self_dtype if dtype is None else dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype〡dtype(self_rank_dtype: Tuple[int, int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def nvprims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int: + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype_layout〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.complex64)) +def aten〇to〇device〡dtype(self_rank_dtype: Tuple[int, int], device: device, dtype: int, non_blocking: bool 
= False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + # ============================================================================== # Main # ============================================================================== diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir index 771ce5a9582a..686669003a2c 100644 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ b/test/Dialect/Torch/refine-types-ops.mlir @@ -16,18 +16,6 @@ func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vte return %1 : !torch.vtensor } -// ----- -// CHECK-LABEL: func.func @torch.aten.type_as( -// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>, -// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor { -// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor { - %ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor - return %ret: !torch.tensor -} - // ----- // CHECK-LABEL: func.func @torch.aten.cat( // CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>, @@ -126,23 +114,6 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, return %ret : !torch.tensor } -// ----- -// CHECK-LABEL: func.func @torch.aten.to.dtype( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor -// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype -// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : -// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -// CHECK-SAME: -> !torch.tensor<*,si64> -// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK-NEXT: return %[[RES]] : !torch.tensor -func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{ - %none = torch.constant.none - %false = torch.constant.bool false - %int4 = torch.constant.int 4 - %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor - return %0 : !torch.tensor -} - // ----- // CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar( // CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
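The rules the new conversion-op dtype functions encode can likewise be checked in eager mode. A minimal sketch, assuming a recent PyTorch release (illustrative only, not part of the patch):

```python
# An explicit `dtype` argument determines the result dtype; otherwise the
# result takes the dtype of the reference (`other`) tensor.
import torch

x = torch.zeros(2, dtype=torch.float32)
other = torch.zeros(2, dtype=torch.int32)

assert x.to(torch.int64).dtype == torch.int64              # aten.to.dtype
assert x.to(other).dtype == torch.int32                    # aten.to.other
assert x.type_as(other).dtype == torch.int32               # aten.type_as
assert x.to("cpu", torch.float16).dtype == torch.float16   # aten.to.device overload
```

Of these, only `aten.to.dtype_layout` takes an optional `dtype` and falls back to the input tensor's dtype when it is `None`, matching the `*_like` pattern above; `aten.to.other` and `aten.type_as` always propagate the other tensor's dtype.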