Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hasDtype checks everywhere dtypes are used in decompositions #1750

Merged
merged 1 commit into from
Jan 3, 2023

Conversation

ramiro050
Copy link
Collaborator

There are several decompositions that assume the operands of the op have dtypes available; however, the only time dtypes are guaranteed to be present is when the graph has reached the backend contract. In general, every pass that happens before reaching the backend contract should not assume dtypes are available and should use hasDtype to check first.

This commit adds hasDtype checks to every decomposition that uses dtypes.

There are several decompositions that assume the operands of the op
have dtypes available; however, the only time dtypes are guaranteed to
be present is when the graph has reached the backend contract. In
general, every pass that happens before reaching the backend contract
should not assume dtypes are available and should use `hasDtype` to
check first.

This commit adds `hasDtype` checks to every decomposition that uses
dtypes.
@ramiro050 ramiro050 marked this pull request as ready for review December 22, 2022 22:56
@ramiro050 ramiro050 linked an issue Dec 22, 2022 that may be closed by this pull request
Copy link
Collaborator

@AmosLewis AmosLewis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@vivekkhandelwal1
Copy link
Collaborator

vivekkhandelwal1 commented Dec 23, 2022

Hi @ramiro050, I have a doubt regarding this patch. In some places, you have directly used getOptionalDtype, while in some places you have used getDtype, and thrown an error if not present. What's the rationale behind doing this?

@ramiro050
Copy link
Collaborator Author

ramiro050 commented Dec 27, 2022

Hi @ramiro050, I have a doubt regarding this patch. In some places, you have directly used getOptionalDtype, while in some places you have used getDtype, and thrown an error if not present. What's the rationale behind doing this?

That's a good question. There are two scenarios that take place regarding the use of getDtype:

  1. The first scenario is the dtype of a tensor is used to create a new type for another tensor. In this case, it does not actually matter if the dtype exists or not. If the current dtype is unk, then the new tensor will get assigned that dtype too, and that is okay to do.
  2. The second scenario is the dtype is used to determine what IR to generate in the decomposition. For example, if we need to cast a tensor from one type to another, we need to know the actual value of dtype to know how to perform the cast correctly.

For the first scenario, I used the getOptionalDtype since knowing the dtype is not truly necessary. This allows some decompositions that don't depend conditionally on the value of dtype to still succeed. For example, there are several decompositions where the result tensor has the same dtype as the input tensor and dtypes only come about in the decomposition when creating the type for the result tensor. In this case, it does not matter if dtype exists or not for the decomposition to succeed.

The the second scenario, because it is important for dtype to have a value, I added error messages when the dtype does not exist.

Let me know if this makes sense to you and if the changes look good.

@li-plus
Copy link
Collaborator

li-plus commented Dec 28, 2022

Hi, I have a few more questions.

Why can't we determine the dtype before the decomposition happens? If we can, is it a better way to handle this issue? (like finalizing the dtype in RefineTypes)

In the second scenario, will there be unexpected decomposition errors in some combination of ops, when we can determine the dtype but we didn't? Like tanh + select.dim in this issue #1748.

@ramiro050
Copy link
Collaborator Author

Hi, I have a few more questions.

Why can't we determine the dtype before the decomposition happens? If we can, is it a better way to handle this issue? (like finalizing the dtype in RefineTypes)

Currently the dtype in your example is not being determined before decompose complex ops because we are in the process of migrating from using RefineTypes to using the dtype functions defined in abstract_interp_lib_gen.py. (See: https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md). The order of the passes is:

createTorchShapeRefinementPipeline(pm);
createTorchDtypeRefinementPipeline(pm);
// Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass());
// This can fold away some branches given the information got from
// RefineTypes before doing maximize value sematics which only works with
// basic blocks.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));

createTorchDtypeRefinementPipeline is the set of passes that use the dtype functions in abstract_interp_lib_gen.py.

Because your example does an aten.add.Tensor first, it needs to wait for RefineTypes to happen to get its result type. Then aten.tanh depends on createTorchDtypeRefinementPipeline, which already passed, so no result dtype is determined for aten.tanh, leaving aten.squeeze.dim with no dtypes in its inputs when DecomposeComplexOps is reached, leading to the error you're seeing.

However, these set of passes are meant to be run over and over again to avoid the catch-22 issue outlined above between the two dtype propagation passes (see: 57681f7). In your case, all that is needed is for the set of passes to be run one more time. What prevents this from happening is that DecomposeComplexOps currently assumes dtypes will be available when in reality there is no guarantee of that.

In the second scenario, will there be unexpected decomposition errors in some combination of ops, when we can determine the dtype but we didn't? Like tanh + select.dim in this issue #1748.

My patch fixes the example in your issue. I used your example to make sure things worked locally.

Copy link
Collaborator

@AmosLewis AmosLewis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I try to cherry pick this patch for the nod-ai/SHARK-Studio#338. But it fails.
Here is the torchscipt ir.
Please run torchscript IR to torch the backend pipeline over that, and it will fail.
torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints})' /tmp/gpt2_torch_raw_elide.mlir --mlir-print-ir-after-all > gpt2_tosa_ramiro.mlir
You will get this after refine type
%884 = torch.aten.tanh %883 : !torch.vtensor<[1,5,3072],f32> -> !torch.vtensor<[1,5,3072],unk>

@ramiro050
Copy link
Collaborator Author

I try to cherry pick this patch for the nod-ai/SHARK#338. But it fails. Here is the torchscipt ir. Please run torchscript IR to torch the backend pipeline over that, and it will fail. torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints})' /tmp/gpt2_torch_raw_elide.mlir --mlir-print-ir-after-all > gpt2_tosa_ramiro.mlir You will get this after refine type %884 = torch.aten.tanh %883 : !torch.vtensor<[1,5,3072],f32> -> !torch.vtensor<[1,5,3072],unk>

Please take a look at #1769, and let me know if that fixes it

@ramiro050
Copy link
Collaborator Author

@AmosLewis, while #1769 might fix the issue you're having with tanh, this PR also makes changes that are important to have to make decompositions more robust. If things look good to you here, I can go ahead and merge this.

Note: the commit message in #1769 explains why things still fail in GPT-2 when using this patch.

@AmosLewis
Copy link
Collaborator

AmosLewis commented Jan 3, 2023

@AmosLewis, while #1769 might fix the issue you're having with tanh, this PR also makes changes that are important to have to make decompositions more robust. If things look good to you here, I can go ahead and merge this.

Note: the commit message in #1769 explains why things still fail in GPT-2 when using this patch.

The tanh works. I think @vivekkhandelwal1 also mentioned there is a similar refinetype bug for AtenRsubScalarOp he worked on.

@ramiro050 ramiro050 merged commit d44bdd2 into llvm:main Jan 3, 2023
@ramiro050 ramiro050 deleted the check-dtype-decomp branch January 3, 2023 22:19
@ramiro050
Copy link
Collaborator Author

The tanh works. I think @vivekkhandelwal1 also mentioned there is a similar refinetype bug for AtenRsubScalarOp he worked on.

@vivekkhandelwal1, let me know if this PR and #1769 don't fix your issue.

@vivekkhandelwal1
Copy link
Collaborator

The tanh works. I think @vivekkhandelwal1 also mentioned there is a similar refinetype bug for AtenRsubScalarOp he worked on.

@vivekkhandelwal1, let me know if this PR and #1769 don't fix your issue.

Hi @ramiro050, the issue with AtenRsubScalarOp still exists. I have to add it to the RefineTypes.cpp to work.

@ramiro050
Copy link
Collaborator Author

Hi @ramiro050, the issue with AtenRsubScalarOp still exists. I have to add it to the RefineTypes.cpp to work.

I think the best way forward is to finish the transition to using dtype functions written in Python. I have made an issue to track the progress: #1807 This should once and for all fix these issues.

@vivekkhandelwal1
Copy link
Collaborator

Hi @ramiro050, the issue with AtenRsubScalarOp still exists. I have to add it to the RefineTypes.cpp to work.

I think the best way forward is to finish the transition to using dtype functions written in Python. I have made an issue to track the progress: #1807 This should once and for all fix these issues.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Some decomposition don't check for hasDtype before calling getDtype
4 participants