Add dtype functions for rand* and var_mean* ops #2089

Merged · 2 commits · May 4, 2023
132 changes: 132 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -5935,6 +5935,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %3 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.testing_framework._convert_dtype_to_int(%arg0: !torch.int) -> !torch.int {\n"
" return %arg0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.triu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9861,6 +9864,135 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int4 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.randn\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
" %int6 = torch.constant.int 6\n"
" %int9 = torch.constant.int 9\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.tuple<int, int>) {\n"
" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %5 : !torch.tuple<int, int>\n"
" } else {\n"
" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.tuple<int, int>) {\n"
" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %7 : !torch.tuple<int, int>\n"
" } else {\n"
" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %7 : !torch.tuple<int, int>\n"
" }\n"
" torch.prim.If.yield %6 : !torch.tuple<int, int>\n"
" }\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
" %int6 = torch.constant.int 6\n"
" %int9 = torch.constant.int 9\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.tuple<int, int>) {\n"
" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %5 : !torch.tuple<int, int>\n"
" } else {\n"
" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.tuple<int, int>) {\n"
" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %7 : !torch.tuple<int, int>\n"
" } else {\n"
" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %7 : !torch.tuple<int, int>\n"
" }\n"
" torch.prim.If.yield %6 : !torch.tuple<int, int>\n"
" }\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
"}\n"
"";
// clang-format on
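
The bare integer constants in the generated library above (%int4, %int6, %int7, %int9, %int10) are PyTorch ScalarType codes rather than tensor values: 4 is torch.int64, 6 is torch.float32, 7 is torch.float64, 9 is torch.complex64, and 10 is torch.complex128. A minimal sketch for confirming that mapping from eager Python; the _scalar_type_code helper is hypothetical and only mirrors the _convert_dtype_to_int helper added further down in this PR:

import torch

@torch.jit.script
def _scalar_type_code(dtype: torch.dtype) -> int:
    # In TorchScript a dtype is already an int (its ScalarType code).
    return dtype

# Prints 4, 6, 7, 9, 10 for the dtypes used by the new dtype functions.
for d in (torch.int64, torch.float32, torch.float64,
          torch.complex64, torch.complex128):
    print(_scalar_type_code(d), d)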
51 changes: 0 additions & 51 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -85,16 +85,6 @@ static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
return failed(result) ? Type() : *result;
}

static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
Type defaultDtype) {
int64_t dtypeInt;
if (matchPattern(optionalDtype, m_TorchConstantInt(&dtypeInt)))
return getTypeForDTypeInteger(context, dtypeInt);
else if (optionalDtype.getType().isa<Torch::NoneType>())
return defaultDtype;
return Type();
}

// Get the kind enum for `ValueKnowledge.kind`.
static torch_upstream::TypeKind getTypeKind(Type type) {
if (type.isa<NumberType>())
@@ -708,47 +698,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

if (auto randIntLow = dyn_cast<AtenRandintLowOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type defaultDtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
knowledge.dtype =
getDtypeOrDefault(op->getContext(), randIntLow.getDtype(), defaultDtype);
incorporateKnowledge(randIntLow.getResult(), knowledge);
return;
}

if (isa<AtenVarMeanCorrectionOp, AtenVarMeanOp>(op)) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
incorporateKnowledge(op->getResult(0), knowledge);
incorporateKnowledge(op->getResult(1), knowledge);
return;
}

if (auto randn = dyn_cast<AtenRandnOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type defaultDtype = Float32Type::get(op->getContext());
knowledge.dtype =
getDtypeOrDefault(op->getContext(), randn.getDtype(), defaultDtype);
incorporateKnowledge(randn.getResult(), knowledge);
return;
}

if (auto randnGenerator = dyn_cast<AtenRandnGeneratorOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type defaultDtype = Float32Type::get(op->getContext());
knowledge.dtype = getDtypeOrDefault(op->getContext(),
randnGenerator.getDtype(), defaultDtype);
incorporateKnowledge(randnGenerator.getResult(), knowledge);
return;
}

if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
@@ -2926,6 +2926,56 @@ def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: T
other_rank, other_dtype = other_rank_dtype
return other_dtype

@check_dtype_function([Invocation(low=0, high=10, size=[1]),
Invocation(low=0, high=10, size=[1], dtype=torch.float32),
Invocation(low=0, high=10, size=[1], dtype=torch.int32),
ErrorInvocation(low=0, high=10, size=[1], dtype=torch.complex64)])
def aten〇randint〇low〡dtype(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is None:
return torch.int64
assert not is_complex_dtype(dtype)
return dtype

@check_dtype_function([Invocation(size=[1]),
Invocation(size=[1], dtype=torch.float32),
ErrorInvocation(size=[1], dtype=torch.int32),
Invocation(size=[1], dtype=torch.complex64)])
def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is None:
return torch.float32
assert not is_integer_dtype(dtype)
return dtype

@check_dtype_function([Invocation(size=[1], generator=None),
Invocation(size=[1], generator=None, dtype=torch.float32),
ErrorInvocation(size=[1], generator=None, dtype=torch.int32),
Invocation(size=[1], generator=None, dtype=torch.complex64)])
def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is None:
return torch.float32
assert not is_integer_dtype(dtype)
return dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes()))
def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
if self_dtype == torch.complex64:
return torch.float32, self_dtype
if self_dtype == torch.complex128:
return torch.float64, self_dtype
return self_dtype, self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes()))
def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
if self_dtype == torch.complex64:
return torch.float32, self_dtype
if self_dtype == torch.complex128:
return torch.float64, self_dtype
return self_dtype, self_dtype
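
# Illustrative sketch (not part of the patch): the complex handling in the two
# var_mean dtype functions above mirrors eager PyTorch, where the variance of a
# complex tensor is real while the mean keeps the input dtype.
x = torch.randn(8, dtype=torch.complex64)
var, mean = torch.var_mean(x)
assert var.dtype == torch.float32 and mean.dtype == torch.complex64

y = torch.randn(8, dtype=torch.complex128)
var, mean = torch.var_mean(y)
assert var.dtype == torch.float64 and mean.dtype == torch.complex128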

# ==============================================================================
# Main
# ==============================================================================
@@ -258,6 +258,15 @@ def decorator(f):
return f
return decorator

@torch.jit.script
def _convert_dtype_to_int(dtype: torch.dtype) -> int:
"""Convert a PyTorch `dtype` into its underlying `int` representation.

This works because in TorchScript there is no special type for `dtypes`;
they are simply `int`s.
"""
return dtype

def check_dtype_function(invocations: List[Invocation]):
"""Decorator that automatically tests a dtype function.

@@ -281,7 +290,12 @@ def decorator(f):
golden_dtype = torch.tensor([]).to(type(golden_result)).dtype
else:
raise ValueError(f"Unhandled return type {type(golden_result)}")
if result_dtype != golden_dtype:
# Some dtype functions have default `dtype` parameters, which are
# represented as `int` values in the registry. To support returning
# that default `int` value directly, the result and golden dtypes are
# compared through their underlying `int` representation.
if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype):
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
return f
return decorator
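
# Illustrative sketch (not part of the patch): why the int-based comparison above
# accepts a raw default value. For Invocation(low=0, high=10, size=[1]), the dtype
# function aten〇randint〇low〡dtype returns its registry default 4, while the golden
# result computed with real tensors has dtype torch.int64; routing both through
# _convert_dtype_to_int lets the two compare equal.
result_dtype = 4                                  # raw default from the registry
golden_dtype = torch.randint(0, 10, [1]).dtype    # torch.int64
assert _convert_dtype_to_int(result_dtype) == _convert_dtype_to_int(golden_dtype)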