Add dtype functions for floating point ops (llvm#1813)
li-plus authored and ramiro050 committed Jan 26, 2023
1 parent fe06ca5 commit 8b8d071
Showing 6 changed files with 208 additions and 149 deletions.
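The rule encoded by the new dtype functions is PyTorch's behavior for unary floating-point ops: float64, bfloat16, and float16 inputs keep their dtype, and every other input dtype (integer and bool included) produces float32. A minimal illustrative sketch of that rule, spot-checked against eager PyTorch (the helper name mirrors the one introduced in this commit; the snippet itself is not part of the change):

import torch

def _get_dtype_of_floating_point_op(input_dtype: torch.dtype) -> torch.dtype:
    # Same rule as the commit's helper, written over torch.dtype objects rather
    # than the integer ScalarType codes used by the generated dtype functions.
    if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
        return input_dtype
    return torch.float32

# With the default dtype left at float32, tanh follows exactly this rule.
for dtype in (torch.float32, torch.float64, torch.bfloat16, torch.int64, torch.bool):
    actual = torch.tanh(torch.ones(3, dtype=dtype)).dtype
    assert actual == _get_dtype_of_floating_point_op(dtype), (dtype, actual)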
2 changes: 1 addition & 1 deletion build_tools/python_deploy/build_linux_packages.sh
@@ -277,7 +277,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
}

function setup_venv() {
132 changes: 106 additions & 26 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7356,6 +7356,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %arg0 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
@@ -7526,32 +7632,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %true = torch.constant.bool true\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %3 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
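An aside on reading the generated MLIR above: the bare integer constants (%int5, %int6, %int7, %int15, %int3, %int4, %int11) are PyTorch ScalarType codes (float16 = 5, float32 = 6, float64 = 7, bfloat16 = 15, int32 = 3, int64 = 4, bool = 11), which is also why the Python dtype functions below take self_dtype as an int. A small illustrative check of those codes via TorchScript, where Tensor.dtype is typed as int (not part of the change):

import torch

@torch.jit.script
def dtype_code(t: torch.Tensor) -> int:
    # In TorchScript, Tensor.dtype is the integer ScalarType code.
    return t.dtype

assert dtype_code(torch.zeros(1, dtype=torch.float32)) == 6
assert dtype_code(torch.zeros(1, dtype=torch.float64)) == 7
assert dtype_code(torch.zeros(1, dtype=torch.bfloat16)) == 15
assert dtype_code(torch.zeros(1, dtype=torch.bool)) == 11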
17 changes: 0 additions & 17 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -676,23 +676,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
PrimsSqrtOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;
if (dtype) {
knowledge.dtype = Float32Type::get(op->getContext());
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
knowledge.dtype = dtype;
}
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

// Take dtype from second operand.
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue();
@@ -1032,19 +1032,107 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
dtype function instead of using this helper function.
"""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]

def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
"""Generate invocations for floating point only op."""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
]

def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
return input_dtype
return torch.float32

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
@@ -1173,13 +1261,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
return self_dtype
else:
return torch.float32

# ==============================================================================
# Main
# ==============================================================================
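The ErrorInvocation cases in _get_invocations_for_fp_only_op_with_tensor_arg_followed_by above reflect that softplus and frobenius_norm.dim reject integer and bool inputs in PyTorch itself, which is why their dtype functions assert on those dtypes instead of promoting them. An illustrative eager-mode check for softplus (frobenius_norm.dim behaves analogously; the exact error message varies by PyTorch version, and this snippet is not part of the change):

import torch
import torch.nn.functional as F

print(F.softplus(torch.ones(4)).dtype)  # torch.float32: floating-point input keeps a float dtype

try:
    F.softplus(torch.ones(4, dtype=torch.int64))  # integer input is rejected by PyTorch
except RuntimeError as err:
    print("softplus on int64 raised:", err)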
34 changes: 0 additions & 34 deletions test/Dialect/Torch/refine-types-branch.mlir
@@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
return %ret: !torch.optional<tensor>
}

// -----

// CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor):
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
return %1 : !torch.vtensor
}

// -----

// CHECK-LABEL: func.func @f
// CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func.func @f() {
builtin.module {
func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
return
}
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> ()
return
}
}
return
}