-
Notifications
You must be signed in to change notification settings - Fork 539
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dtype functions for floating point ops #1813
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1026,19 +1026,107 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args): | |
dtype function instead of using this helper function. | ||
""" | ||
return [ | ||
Invocation(NonZeroDTensorWithDtype(torch.float32), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.float64), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.int64), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.int32), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.bool), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.float32), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.float64), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.int64), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.int32), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.bool), *args), | ||
] | ||
Invocation(NonZeroDTensorWithDtype(torch.float32), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.float64), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.int64), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.int32), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.bool), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.float32), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.float64), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.int64), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.int32), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.bool), *args), | ||
] | ||
|
||
def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will be making a PR soon that improves the testing helper functions quite a bit, so if you're working on another set of ops, I would wait until that lands There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome! I haven't started working on the next task. Will wait for your PR before further development. |
||
"""Generate invocations for floating point only op.""" | ||
return [ | ||
Invocation(NonZeroDTensorWithDtype(torch.float32), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.float64), *args), | ||
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), | ||
ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args), | ||
ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args), | ||
ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.float32), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.float64), *args), | ||
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), | ||
ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args), | ||
ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args), | ||
ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args), | ||
] | ||
|
||
def _get_dtype_of_floating_point_op(input_dtype: int) -> int: | ||
if input_dtype in (torch.float64, torch.bfloat16, torch.float16): | ||
return input_dtype | ||
return torch.float32 | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by()) | ||
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: | ||
assert self_dtype not in (torch.int64, torch.int32, torch.bool) | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0])) | ||
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int: | ||
assert self_dtype not in (torch.int64, torch.int32, torch.bool) | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: | ||
return _get_dtype_of_floating_point_op(self_dtype) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int: | ||
|
@@ -1167,13 +1255,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int | |
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: | ||
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) | ||
|
||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) | ||
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: | ||
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16: | ||
return self_dtype | ||
else: | ||
return torch.float32 | ||
|
||
# ============================================================================== | ||
# Main | ||
# ============================================================================== | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine, thanks!