
Rewrite dtype functions for softmax ops in Python #1987

Draft · wants to merge 1 commit into base: dtype-functions-staging

Conversation

gpetters94 (Collaborator)

No description provided.

@gpetters94 gpetters94 requested a review from ramiro050 March 29, 2023 23:25
gpetters94 (Collaborator, Author) commented Mar 29, 2023:

@ramiro050 these are all one-liners, do you think they require tests?

ramiro050 (Collaborator) replied:

> @ramiro050 these are all one-liners, do you think they require tests?

Yes, every dtype function should be tested. What's nice is that the testing code is also a one-liner (two-liner tops) for these ops 🙂

See below for examples of testing ops with one tensor input and with two tensor inputs:

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
    return torch.bool

@check_dtype_function(_check_two_tensor_op())
def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool
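
For context, a rough picture of why these tests stay so short: the helper builds a list of Invocations covering many input dtypes, and check_dtype_function runs the decorated dtype function against each one. The sketch below is only an illustration of that idea, not torch-mlir's actual helper; the name _check_tensors_with_the_same_dtype_sketch and the dtype list are assumptions, while Invocation and TensorOfShape are assumed to be the testing-framework classes used in the snippets above.

import torch

def _check_tensors_with_the_same_dtype_sketch(num_of_tensors, **op_kwargs):
    # Build one Invocation per candidate dtype; extra keyword arguments are
    # forwarded to the op, the way other=0.0 is forwarded in the example above.
    invocations = []
    for dtype in (torch.float32, torch.float64, torch.int64, torch.bool):
        tensors = [TensorOfShape(2, 3, 4, dtype=dtype) for _ in range(num_of_tensors)]
        invocations.append(Invocation(*tensors, **op_kwargs))
    return invocations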

ramiro050 (Collaborator) commented:

The CI is failing on the refine-types lit tests. Since we are slowly removing logic from RefineTypes, you can just delete the failing MLIR test.

gpetters94 (Collaborator, Author) commented:

@ramiro050 It's failing only on dynamo; is this a known issue? Here's the IR before the failure:

func.func @forward(%arg0: !torch.vtensor<[3,2,4],f32>, %arg1: !torch.vtensor<[3,2,4],f32>) -> !torch.vtensor<[3,2,4],unk> {
  %int1 = torch.constant.int 1
  %none = torch.constant.none
  %int6 = torch.constant.int 6
  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[3,2,4],f32>, !torch.none -> !torch.vtensor<[3,2,4],f32>
  %1 = torch.aten.clone %arg1, %none : !torch.vtensor<[3,2,4],f32>, !torch.none -> !torch.vtensor<[3,2,4],f32>
  %2 = torch.aten._softmax_backward_data %0, %1, %int1, %int6 : !torch.vtensor<[3,2,4],f32>, !torch.vtensor<[3,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,2,4],unk>
  return %2 : !torch.vtensor<[3,2,4],unk>
}

@@ -2232,6 +2232,36 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
    assert is_float_dtype(input_dtype)
    return input_dtype, input_dtype, input_dtype

@check_dtype_function([
    Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), dim=0, half_to_float=False)])
Review comment (Collaborator):

Can you use the helper functions for testing? These create invocations that make very thorough checks of the dtype functions. Same comment applies to the rest of the functions. See:

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
    return torch.bool

@check_dtype_function(_check_two_tensor_op())
def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool
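
For the softmax op in the diff above, the helper-based decorator might look roughly like the following. This is a sketch rather than the PR's code: the function name, signature, and body are inferred from the Invocation in the diff (dim, half_to_float), and it assumes the helper forwards extra keyword arguments to the op the same way other=0.0 is forwarded in the eq example.

# Sketch only; the actual dtype logic is whatever the PR implements.
@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False))
def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype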

@@ -619,19 +614,17 @@ void TypeAnalysis::visitOperation(Operation *op,
          AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
          AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
          AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
          AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
          AtenTanhBackwardOp, AtenHardtanhBackwardOp,
Review comment (Collaborator):

There is already a patch that will handle the ops here: #1895. Can you undo these changes to avoid conflicts?

@check_dtype_function([
    Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), TensorOfShape(2, 3, 4, dtype=torch.float32), dim=0, input_dtype=torch.float32)])
def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
    return grad_output_rank_dtype[1]
Review comment (Collaborator):

I think the reason the test is failing is that there is currently no support for indexing tuples in dtype functions. Can you change this to grad_output_rank, grad_output_dtype = grad_output_rank_dtype, like the other functions here, to see if it fixes it?
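
A minimal sketch of the suggested change, keeping the decorator from the diff and only replacing the tuple indexing with unpacking:

@check_dtype_function([
    Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32),
               TensorOfShape(2, 3, 4, dtype=torch.float32), dim=0, input_dtype=torch.float32)])
def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int],
                                       output_rank_dtype: Tuple[int, int],
                                       dim: int, input_dtype: int) -> int:
    # Unpack the (rank, dtype) tuple instead of indexing it, since dtype
    # functions do not currently support tuple indexing.
    grad_output_rank, grad_output_dtype = grad_output_rank_dtype
    return grad_output_dtype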

ramiro050 (Collaborator) commented:

Hey George, any updates on the dtype functions?
