Rewrite dtype functions for softmax ops in Python #1987
base: dtype-functions-staging
Conversation
@ramiro050 these are all one-liners; do you think they require tests?
Yes, every dtype function should be tested. What's nice is that the testing code is also a one-liner (two-liner tops) for these ops 🙂 See below for examples of testing ops with one tensor input and with two tensor inputs:
torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py, lines 1305 to 1312 in d3a49fd
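For instance, using the Invocation and TensorOfShape helpers from abstract_interp_lib_gen.py, a minimal test for one of the softmax dtype functions could look roughly like this (a sketch: the op name, decorator arguments, and body are illustrative rather than the exact code in this patch):

@check_dtype_function([
    Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), dim=0, half_to_float=False)])
def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
    # With half_to_float=False the op keeps the input dtype, so the dtype
    # function just returns it (the half_to_float=True promotion case is
    # ignored in this sketch).
    self_rank, self_dtype = self_rank_dtype
    return self_dtype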
The CI is failing on the refine-types lit tests. Since we are slowly removing logic from …
fd8d8af to 54d7f28
@ramiro050 It's failing only on dynamo; is this a known issue? Here's the IR before the failure:
@@ -2232,6 +2232,36 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
    assert is_float_dtype(input_dtype)
    return input_dtype, input_dtype, input_dtype

@check_dtype_function([
    Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), dim=0, half_to_float=False)])
Can you use the helper functions for testing? These create invocations that make very thorough checks of the dtype functions. Same comment applies to the rest of the functions. See:
torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
Lines 1305 to 1312 in d3a49fd
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
    return torch.bool

@check_dtype_function(_check_two_tensor_op())
def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool
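Applied to the softmax change above, the helper-based test could look roughly like this (a sketch, assuming _check_tensors_with_the_same_dtype forwards the extra keyword arguments to the op as in the eq examples; the exact arguments may need adjusting per softmax variant):

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False))
def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
    # Sketch only: returns the input dtype, matching the half_to_float=False case.
    self_rank, self_dtype = self_rank_dtype
    return self_dtype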
@@ -619,19 +614,17 @@ void TypeAnalysis::visitOperation(Operation *op,
        AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
        AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
        AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
        AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
        AtenTanhBackwardOp, AtenHardtanhBackwardOp,
There is already a patch that will handle the ops here: #1895. Can you undo these changes to avoid conflicts?
@check_dtype_function([
    Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), TensorOfShape(2, 3, 4, dtype=torch.float32), dim=0, input_dtype=torch.float32)])
def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
    return grad_output_rank_dtype[1]
I think the reason the test is failing is that there is currently no support for indexing tuples in dtype functions. Can you change this to grad_output_rank, grad_output_dtype = grad_output_rank_dtype, like the other functions here, to see if that fixes it?
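In other words, a sketch of the suggested edit (only the tuple indexing changes; the decorator and signature stay the same):

def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
    # Unpack the (rank, dtype) tuple instead of indexing it, since tuple
    # indexing is not currently supported in dtype functions.
    grad_output_rank, grad_output_dtype = grad_output_rank_dtype
    return grad_output_dtype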
Hey George, any updates on the dtype functions?