Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype functions for floating point ops #1813

Merged
merged 1 commit into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, thanks!

}

function setup_venv() {
Expand Down
132 changes: 106 additions & 26 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7350,6 +7350,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %arg0 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
Expand Down Expand Up @@ -7520,32 +7626,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %true = torch.constant.bool true\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %3 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
Expand Down
17 changes: 0 additions & 17 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,23 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
PrimsSqrtOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;
if (dtype) {
knowledge.dtype = Float32Type::get(op->getContext());
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
knowledge.dtype = dtype;
}
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

// Take dtype from second operand.
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1026,19 +1026,107 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
dtype function instead of using this helper function.
"""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]

def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will be making a PR soon that improves the testing helper functions quite a bit, so if you're working on another set of ops, I would wait until that lands

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! I haven't started working on the next task. Will wait for your PR before further development.

"""Generate invocations for floating point only op."""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
]

def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
return input_dtype
return torch.float32

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
Expand Down Expand Up @@ -1167,13 +1255,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
return self_dtype
else:
return torch.float32

# ==============================================================================
# Main
# ==============================================================================
Expand Down
34 changes: 0 additions & 34 deletions test/Dialect/Torch/refine-types-branch.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
return %ret: !torch.optional<tensor>
}

// -----

// CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor):
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
return %1 : !torch.vtensor
}

// -----

// CHECK-LABEL: func.func @f
// CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func.func @f() {
builtin.module {
func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
return
}
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> ()
return
}
}
return
}
Loading