
[MLIR][TORCH][TOSA] Add e2e Tosa support for aten.as_strided #1742

Closed
wants to merge 1 commit

Conversation

AmosLewis
Collaborator

@AmosLewis AmosLewis commented Dec 21, 2022

Hit this bug when lowering DistilGPT2 to TOSA: nod-ai/SHARK-Studio#494

// -----
// CHECK-LABEL:   func.func @torch.aten.as_strided(
// CHECK-SAME:                                     %[[VAL_0:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
// CHECK:           %[[VAL_2:.*]] = torch.constant.int 2
// CHECK:           %[[VAL_3:.*]] = torch.constant.int 1
// CHECK:           %[[VAL_4:.*]] = torch.constant.int 0
// CHECK:           %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:           %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:           %[[VAL_7:.*]] = "tosa.const"() {value = dense<{{\[\[}}0, 2, 1, 3]]> : tensor<1x4xi32>} : () -> tensor<1x4xi32>
// CHECK:           %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [1, 9, 1]} : (tensor<3x3xf32>) -> tensor<1x9x1xf32>
// CHECK:           %[[VAL_9:.*]] = "tosa.gather"(%[[VAL_8]], %[[VAL_7]]) : (tensor<1x9x1xf32>, tensor<1x4xi32>) -> tensor<1x4x1xf32>
// CHECK:           %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_9]]) {new_shape = [2, 2]} : (tensor<1x4x1xf32>) -> tensor<2x2xf32>
// CHECK:           %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK:           return %[[VAL_11]] : !torch.vtensor<[2,2],f32>
// CHECK:         }
func.func @torch.aten.as_strided(%arg0: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,2],f32> {
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[2,2],f32>
  return %2 : !torch.vtensor<[2,2],f32>
}
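
For context on the [[0, 2, 1, 3]] constant that feeds tosa.gather above: the flat indices come from the as_strided sizes, strides, and storage offset. A minimal Python sketch of the index computation (illustrative only, not code from this PR):

    import itertools

    def as_strided_gather_indices(sizes, strides, offset=0):
        # Flat storage indices the as_strided view reads, in row-major order.
        return [offset + sum(i * s for i, s in zip(idx, strides))
                for idx in itertools.product(*(range(n) for n in sizes))]

    # The test above: sizes (2, 2), strides (1, 2), offset 0 -> [0, 2, 1, 3]
    print(as_strided_gather_indices((2, 2), (1, 2), 0))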

@AmosLewis AmosLewis force-pushed the as_stride branch 2 times, most recently from bf5a1a2 to 89ec255 on December 22, 2022 00:35
@AmosLewis
Collaborator Author

AmosLewis commented Dec 22, 2022

distillgpt2_torch_delete_decompose_amax_selectint.mlir

Need this refine-types patch to fix the amax decomposition issue:
#1745

@AmosLewis

This comment was marked as outdated.

@AmosLewis AmosLewis marked this pull request as ready for review December 22, 2022 02:40
@AmosLewis AmosLewis changed the title [MLIR][TORCH] Add e2e support for aten.as_stride [MLIR][TORCH] Add e2e support for aten.as_strided Dec 22, 2022
@AmosLewis AmosLewis requested a review from ramiro050 December 22, 2022 06:13
Collaborator

@ramiro050 ramiro050 left a comment

On GitHub I don't see a lowering for as_strided in this PR. Is there a file missing?

@AmosLewis
Collaborator Author

AmosLewis commented Dec 22, 2022

On GitHub I don't see a lowering for as_strided in this PR. Is there a file missing?

I am planning to add the lowering to TOSA in the next PR. This op is required for the distilgpt2 model, and it might be a long patch, so I just added the e2e support first so I can erase the torch.operator.* entries in the MLIR file. Before adding the lowering to TOSA, I have to fix the aten.slice.Tensor-to-TOSA lowering first.

@ramiro050
Collaborator

ramiro050 commented Dec 22, 2022

I am planning to add lower to tosa in the next PR. This is required for the distilgpt2 model.

In general, we should try to avoid adding code that does not get run by the e2e test suite, even if done temporarily. There are several parts in this PR that currently don't get executed by the e2e suite.

It might be a long patch.

The changes here are not that much code. They should all be part of a single patch that adds e2e support. Not only does this help avoid having dead code in torch-mlir, but it makes reviewing easier, since the reviewer can see the declaration of the op, as well as the handling of its dtype, shape, and testing.

@AmosLewis
Collaborator Author

I am planning to add lower to tosa in the next PR. This is required for the distilgpt2 model.

In general, we should try to avoid adding code that does not get run by the e2e test suite, even if done temporarily. There are several parts in this PR that currently don't get executed by the e2e suite.

It might be a long patch.

The changes here are not that much code. They should all be part of a single patch that adds e2e support. Not only does this help avoid having dead code in torch-mlir, but it makes reviewing easier, since the reviewer can see the declaration of the op, as well as the handling of its dtype, shape, and testing.

Ok. I will continue iterating on this patch.

@AmosLewis AmosLewis marked this pull request as draft December 22, 2022 17:48
@AmosLewis AmosLewis force-pushed the as_stride branch 3 times, most recently from 1a43554 to b6caa9b on January 3, 2023 07:01
@AmosLewis AmosLewis marked this pull request as ready for review January 3, 2023 07:01
@AmosLewis AmosLewis changed the title [MLIR][TORCH] Add e2e support for aten.as_strided [MLIR][TORCH][TOSA] Add e2e Tosa support for aten.as_strided Jan 3, 2023
@AmosLewis AmosLewis force-pushed the as_stride branch 2 times, most recently from 5869097 to 9e031e6 on January 3, 2023 18:37
@AmosLewis AmosLewis force-pushed the as_stride branch 2 times, most recently from 38c1771 to 2a45694 on January 3, 2023 22:33
@AmosLewis AmosLewis requested a review from ramiro050 January 3, 2023 22:34
Collaborator

@eric-k256 eric-k256 left a comment

The implementation looks okay to me. Ideally we'd avoid a tosa.gather as it tends to be a slow op for acceleration, but I'm not sure a loop of SLICEs would be significantly better in this case. Looking at the original network, is it doing effectively as_strided(as_strided(as_strided(as_strided(tensor))))?

test/Conversion/TorchToTosa/basic.mlir (outdated review thread, resolved)
lib/Conversion/TorchToTosa/TorchToTosa.cpp (review thread, resolved)
    ])
    def forward(self, x):
        return torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)
Collaborator

Does this implementation work if you do

        return torch.ops.aten.as_strided(x.t(), (2, 2), (1, 2), 1)

I.e., pass the transpose of x as the argument.

Collaborator

To add to this, I think adding support for this op in torch-mlir will be very tricky. The reason is that this op depends on knowledge about the storage used by the input tensor, and at the torch dialect level in torch-mlir there is no notion of storage. In the example I gave above, x.t() does not change the storage of the tensor, so PyTorch returns the same output as when x is passed. However, I expect torch-mlir will output a tensor as if x.t().contiguous() had been passed instead because it does not know that x.t() and x share the same storage.

Where is it that you're seeing this op used? Given the warning in the documentation of the op, I would expect that this op is not explicitly used in the definition of a model, but rather it is being generated by PyTorch when turning the model to JIT IR. If this is the case, then maybe we can find a way to fold it back.
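
A quick runnable check of the storage point above (a sketch using stock PyTorch, reusing the sizes, strides, and offset from the e2e test):

    import torch

    x = torch.arange(9, dtype=torch.float32).reshape(3, 3)
    a = torch.as_strided(x, (2, 2), (1, 2), 1)
    b = torch.as_strided(x.t(), (2, 2), (1, 2), 1)  # x.t() is a view of the same storage
    print(torch.equal(a, b))  # True: as_strided reads the shared storage, not the transposed view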

Collaborator Author

@AmosLewis AmosLewis Jan 4, 2023

With x.t(), it fails.

➜  torch-mlir git:(as_stride) ✗ python -m e2e_testing.main -c tosa -f "AsStridedStaticModule_basic"
FAIL - "AsStridedStaticModule_basic"

Unexpected outcome summary:

****** Failed tests - 1 tests
    FAIL - "AsStridedStaticModule_basic"

Summary:
    Failed: 1

Where is it that you're seeing this op used?
I got it from the transformer distilgpt2 model.
Here is the Python patch I use: distillgpt2.py
Here is the TorchScript IR I got: distillgpt2_torchscript.mlir
Here is the debug torch-mlir IR I got: distilgpt_lambda.mlir
Here is the final TOSA file I generated for distilgpt2: distilgpt2_tosa.mlir

Collaborator Author

it is being generated by PyTorch when turning the model to JIT IR.
Not sure how to tell whether it is generated during the conversion to JIT IR, and furthermore, if it is, how to fix it?

@AmosLewis AmosLewis force-pushed the as_stride branch 2 times, most recently from 3635d78 to f94c723 on January 4, 2023 23:37
@AmosLewis
Collaborator Author

The implementation looks okay to me. Ideally we'd avoid a tosa.gather as it tends to be a slow op for acceleration, but I'm not sure a loop of SLICEs would be significantly better in this case. Looking at the original network, is it doing effectively as_strided(as_strided(as_strided(as_strided(tensor))))?

I thought of using tosa::SliceOp but didn't figure out a way. In the generated distillgpt2_torchscript.mlir I didn't find as_strided(as_strided(as_strided(as_strided(tensor)))), but I did find as_strided(view(tensor)) three times. Could you post a link to the code where you found it?

@eric-k256
Collaborator

I was looking at your code, and this looks like a sequence of nested as_strided:

    %395 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %396 = torch.operator "aten.as_strided"(%393, %394, %395, %int0) : (!torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.tensor loc(#loc182)
    %397 = torch.prim.ListConstruct %int1, %int1, %int1024, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %398 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %399 = torch.operator "aten.as_strided"(%396, %397, %398, %int0) : (!torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.tensor loc(#loc183)
    %400 = torch.prim.ListConstruct %int1, %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %401 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %402 = torch.operator "aten.as_strided"(%399, %400, %401, %int0) : (!torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.tensor loc(#loc184)
    %403 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %404 = torch.prim.ListConstruct %int1048576, %int1048576, %int1024, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> loc(#loc)
    %405 = torch.operator "aten.as_strided"(%402, %403, %404, %int0) : (!torch.tensor, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.tensor loc(#loc185)

Looking at this, I agree with Ramiro's concerns: the warning in the documentation implies some behavior that may not be expressed in the captured MLIR. This is almost certainly an effect of the JIT tracing, and although we may be able to get it to work for this case, if we understand how the JIT tracing generates this sequence, we may be able to map it to a better set of operators.
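
For reference, the captured sizes and strides can be replayed directly in PyTorch. Assuming the source is the contiguous (1, 1, 1024, 1024) buffer implied by strides (1048576, 1048576, 1024, 1), the last three as_strided ops in the IR above reduce to plain slicing (a sketch, not code from the model):

    import torch

    bias = torch.rand(1, 1, 1024, 1024)          # stand-in for the source tensor
    s = (1048576, 1048576, 1024, 1)              # contiguous strides for that shape
    a = torch.as_strided(bias, (1, 1, 1024, 1024), s, 0)
    a = torch.as_strided(a, (1, 1, 128, 1024), s, 0)
    a = torch.as_strided(a, (1, 1, 128, 128), s, 0)
    print(torch.equal(a, bias[:, :, :128, :128]))  # True: the chain is equivalent to slicing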

@AmosLewis AmosLewis force-pushed the as_stride branch 3 times, most recently from edc78eb to 6d1b2f1 on January 16, 2023 06:04
@AmosLewis
Collaborator Author

@ramiro050 @eric-k256 The as_strided comes from the decomposition of torch.ops.aten.slice.Tensor when I use make_fx in Python code after the distilgpt2 model is imported. Removing the slice decomposition gets rid of the as_strided code.
https://github.com/pytorch/pytorch/blob/8f3600b966d896986e334b9a22c43e937ee0169d/torch/_decomp/decompositions.py#L663
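
For reference, make_fx only applies the decompositions it is handed, so leaving torch.ops.aten.slice.Tensor out of the decomposition table keeps aten.slice in the trace instead of expanding it into aten.as_strided. A hedged sketch (the op list below is illustrative, not the exact distilgpt2 script):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._decomp import get_decompositions

    # Build a table that deliberately omits torch.ops.aten.slice.Tensor,
    # so slicing is not rewritten into aten.as_strided during tracing.
    decomp_table = get_decompositions([
        torch.ops.aten.native_layer_norm,  # illustrative entries only
        torch.ops.aten.gelu,
    ])

    def f(x):
        return x[:, 1:3]  # traces to aten.slice.Tensor

    traced = make_fx(f, decomposition_table=decomp_table)(torch.randn(4, 4))
    print(traced.graph)  # contains aten.slice, not aten.as_strided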

@AmosLewis AmosLewis closed this Jan 24, 2023
@ramiro050
Collaborator

@ramiro050 @eric-k256 The as_strided comes from the decomposition of torch.ops.aten.slice.Tensor when I use make_fx in Python code after the distilgpt2 model is imported. Removing the slice decomposition gets rid of the as_strided code. https://github.com/pytorch/pytorch/blob/8f3600b966d896986e334b9a22c43e937ee0169d/torch/_decomp/decompositions.py#L663

Awesome! Thanks for looking into it
