[MLIR][TORCH][TOSA] Add e2e Tosa support for aten.as_strided #1742
Conversation
Force-pushed from bf5a1a2 to 89ec255.
distillgpt2_torch_delete_decompose_amax_selectint.mlir
Need this refine-types patch to fix the decompose amax issue:
On GitHub I don't see a lowering for as_strided in this PR. Is there a file missing?
I am planning to add the lowering to TOSA in the next PR. This is required for the distilgpt2 model. It might be a long patch, so I just added this e2e support first so I can erase the torch.operator.* ops in the MLIR file. Before adding the lowering to TOSA, I have to fix the aten.slice.Tensor-to-TOSA lowering first.
In general, we should try to avoid adding code that does not get run by the e2e test suite, even if done temporarily. Several parts of this PR currently don't get executed by the e2e suite.
The changes here are not that much code. They should all be part of a single patch that adds e2e support. Not only does this help avoid having dead code in torch-mlir, but it also makes reviewing easier, since the reviewer can see the declaration of the op as well as the handling of its dtype, shape, and testing.
OK. I will continue iterating on this patch.
Force-pushed from 1a43554 to b6caa9b.
Force-pushed from 5869097 to 9e031e6.
Force-pushed from 38c1771 to 2a45694.
The implementation looks okay to me. Ideally we'd avoid a tosa.gather as it tends to be a slow op for acceleration, but I'm not sure a loop of SLICEs would be significantly better in this case. Looking at the original network, is it doing effectively as_strided(as_strided(as_strided(as_strided(tensor))))?
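For context on that trade-off, here is a minimal sketch (NumPy, with a hypothetical helper name) of the index computation a gather-based lowering materializes: every output coordinate maps to the flat storage index offset + sum(coord[d] * stride[d]), and those indices become the constant index tensor fed to tosa.gather:

```python
import itertools
import numpy as np

def as_strided_gather_indices(size, stride, storage_offset):
    # Hypothetical helper: flat storage index for each output coordinate.
    indices = [
        storage_offset + sum(c * s for c, s in zip(coord, stride))
        for coord in itertools.product(*(range(d) for d in size))
    ]
    return np.array(indices).reshape(size)

flat = np.arange(16, dtype=np.float32)           # stand-in for the flattened input
idx = as_strided_gather_indices((2, 2), (1, 2), 1)
print(flat[idx])                                 # [[1. 3.] [2. 4.]]
```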
```python
def forward(self, x):
    return torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)
```
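For reference, a quick eager-mode check (not part of the PR) of what this call produces on a contiguous 3x3 input:

```python
import torch

x = torch.arange(9, dtype=torch.float32).reshape(3, 3)
# size (2, 2), stride (1, 2), storage_offset 1:
# output[i, j] reads storage slot 1 + i*1 + j*2.
print(torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1))
# tensor([[1., 3.],
#         [2., 4.]])
```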
Does this implementation work if you do return torch.ops.aten.as_strided(x.t(), (2, 2), (1, 2), 1), i.e. pass the transpose of x as the argument?
To add to this, I think adding support for this op in torch-mlir will be very tricky. The reason is that this op depends on knowledge about the storage used by the input tensor, and at the torch dialect level in torch-mlir there is no notion of storage. In the example I gave above, x.t() does not change the storage of the tensor, so PyTorch returns the same output as when x is passed. However, I expect torch-mlir will output a tensor as if x.t().contiguous() had been passed instead, because it does not know that x.t() and x share the same storage.
Where is it that you're seeing this op used? Given the warning in the documentation of the op, I would expect that this op is not explicitly used in the definition of a model, but rather that it is being generated by PyTorch when turning the model into JIT IR. If this is the case, then maybe we can find a way to fold it back.
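A small eager-mode sketch of the storage issue described above:

```python
import torch

x = torch.arange(9, dtype=torch.float32).reshape(3, 3)

# x.t() is a view: same storage, only shape/strides change.
assert x.t().data_ptr() == x.data_ptr()

# Eager PyTorch therefore returns the same result for both calls,
# because as_strided indexes the shared storage directly.
a = torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)
b = torch.ops.aten.as_strided(x.t(), (2, 2), (1, 2), 1)
assert torch.equal(a, b)

# A lowering with no notion of storage would instead behave as if the
# materialized x.t().contiguous() had been passed, which differs.
c = torch.ops.aten.as_strided(x.t().contiguous(), (2, 2), (1, 2), 1)
assert not torch.equal(a, c)
```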
With x.t(), it fails:
➜ torch-mlir git:(as_stride) ✗ python -m e2e_testing.main -c tosa -f "AsStridedStaticModule_basic"
FAIL - "AsStridedStaticModule_basic"
Unexpected outcome summary:
****** Failed tests - 1 tests
FAIL - "AsStridedStaticModule_basic"
Summary:
Failed: 1
> Where is it that you're seeing this op used?

I got it from the transformer distilgpt2 model.
Here is the Python patch I use: distillgpt2.py
Here is the TorchScript IR I got: distillgpt2_torchscript.mlir
Here is the debug torch MLIR I got: distilgpt_lambda.mlir
Here is the final TOSA file I generated for distilgpt2: distilgpt2_tosa.mlir
> it is being generated by PyTorch when turning the model to JIT IR.

I'm not sure how to check whether it is generated during the JIT IR conversion, and if it is, how to fix it.
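One way to check (a sketch; the Linear module is only a placeholder for the real model) is to trace the model and search the resulting JIT graph for aten::as_strided nodes, then look at which surrounding ops produce and consume them:

```python
import torch

model = torch.nn.Linear(4, 4)             # placeholder for the distilgpt2 model
traced = torch.jit.trace(model, torch.randn(1, 4))

for node in traced.graph.nodes():
    if node.kind() == "aten::as_strided":
        print(node)                        # shows the node and its operands
```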
Force-pushed from 3635d78 to f94c723.
I thought of using tosa::SliceOp but didn't figure out a way. In the generated distillgpt2_torchscript.mlir, I didn't find
I was looking at your code, and this looks like a sequence of nested as_strided:
Looking at this, I agree with Ramiro's concerns: the warning in the documentation implies behavior that may not be expressed in the captured MLIR. This is almost certainly an effect of the JIT tracing, and although we may be able to get it to work for this case, if we understand how the JIT tracing generates this sequence, we may be able to map it to a better set of operators.
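If it really is a chain of nested as_strided calls with explicit storage offsets, one property worth noting (a sketch, checked only in eager PyTorch) is that the composition collapses to the outermost call, since every as_strided indexes the same underlying storage, which suggests a fold:

```python
import torch

x = torch.arange(16, dtype=torch.float32)

# The inner as_strided is dead when the outer one supplies an explicit
# storage offset: both calls index the same underlying storage.
nested = torch.as_strided(torch.as_strided(x, (4,), (2,), 3),
                          (2, 2), (1, 2), 1)
folded = torch.as_strided(x, (2, 2), (1, 2), 1)
assert torch.equal(nested, folded)
```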
Force-pushed from edc78eb to 6d1b2f1.
@ramiro050 @eric-k256 The as_strided is from the decomposition of
Awesome! Thanks for looking into it.
Got this bug in the DistilGPT2-to-TOSA path: nod-ai/SHARK-Studio#494