
Torch to TOSA conversion fails to legalize 'torch.constant.int' #961

Closed
Svoch opened this issue Jun 22, 2022 · 11 comments

Comments


Svoch commented Jun 22, 2022

I am trying to compile a portion of a PyTorch self-attention module down to the TOSA backend and am hitting an error when legalizing the torch.constant.int op in the TOSA conversion pass. The issue arises only when the output type is set to torch_mlir.OutputType.TOSA in the Torch-MLIR compile API; the conversion to the Linalg dialect and further down to the backend works fine, but the Torch to TOSA conversion fails.

Error log

Exception: 
Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.int'
note: see current operation: %5 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int

Steps to reproduce

The script to reproduce the error is available in a draft PR on a local fork. The error can be reproduced using the code snippet below, with the module definition and the torch_mlir.compile() API call:

import torch
import torch.nn as nn
import torch_mlir

class AttentionScores(torch.nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(AttentionScores, self).__init__()
        self.query = nn.Linear(embedding_dim, num_heads)
        self.key = nn.Linear(embedding_dim, num_heads)

    def forward(self, inputs):
        query = self.query(inputs)
        key = self.key(inputs)
        scores = torch.matmul(query, key.transpose(0, 1))
        return scores

attention_scores = AttentionScores(embedding_dim=10, num_heads=2)
inputs = torch.rand(5, 10)
# Fails with: error: failed to legalize operation 'torch.constant.int'
tosa_module = torch_mlir.compile(attention_scores, inputs, output_type=torch_mlir.OutputType.TOSA)
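
For comparison, here is a sketch (not the exact script from the draft PR) of the Linalg path that succeeds for the same module and inputs, assuming the LINALG_ON_TENSORS output type name in the Python API:

# Lowering the same module to the Linalg-on-tensors backend completes without the
# 'torch.constant.int' legalization error.
linalg_module = torch_mlir.compile(
    attention_scores, inputs, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS
)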

This issue is potentially relevant to what @nithinsubbiah, @rdadolf and I are seeing in #910. I was also able to reproduce the error by simplifying the module above to a single transpose op (i.e. torch.Tensor.transpose in the forward method), as in the sketch below.
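
A minimal sketch of that simplified repro (reusing the imports above; the exact module in the draft PR may differ):

class TransposeOnly(torch.nn.Module):
    def forward(self, inputs):
        # A single transpose is enough to trigger the failure when targeting TOSA.
        return inputs.transpose(0, 1)

# Fails with the same "failed to legalize operation 'torch.constant.int'" error.
tosa_transpose = torch_mlir.compile(
    TransposeOnly(), torch.rand(5, 10), output_type=torch_mlir.OutputType.TOSA
)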

cc @sjarus @powderluv @silvasean - I wonder if you have seen this or have any insight into what might be going wrong here.


sjarus commented Jun 22, 2022

I've encountered this already, @Svoch - it also impacts MobileNetV3. I'm working on a fix internally, but am getting some BERT-related ones out first.

@YellowHCH

I just convert ConstantIntOp to arith while converting BERT to TOSA, as an intermediate result.


Svoch commented Jun 28, 2022

Below is the IR I get when lowering the model above to the Torch dialect. If I've understood correctly, even though rewriters for AtenMmOp and AtenLinearOp do exist in the TorchToTosa lowering pass, there is no lowering pattern for AtenTransposeIntOp. That seems to be the underlying problem here. Does that make sense, @sjarus?

module attributes {torch.debug_module_name = "AttentionScores"} {
  func.func @forward(%arg0: !torch.vtensor<[5,10],f32>) -> !torch.vtensor<[5,5],f32> {
    %0 = torch.vtensor.literal(dense<[0.143874153, 0.230392322]> : tensor<2xf32>) : !torch.vtensor<[2],f32>
    %1 = torch.vtensor.literal(dense<[[0.226313293, -0.279624343, -0.211782753, 2.701980e-01, 0.0410803184, 0.25729695, -0.00262214779, -0.0828355625, 0.145104617, -0.0266586915], [0.239182726, 0.00810842216, -0.00369983795, -0.132520735, -0.254919976, 0.0812480971, -0.196122929, 0.10878253, 0.158111736, 0.306294829]]> : tensor<2x10xf32>) : !torch.vtensor<[2,10],f32>
    %2 = torch.vtensor.literal(dense<[0.249187589, 0.0799354762]> : tensor<2xf32>) : !torch.vtensor<[2],f32>
    %3 = torch.vtensor.literal(dense<[[2.573700e-01, 0.215394989, 0.238999024, 0.163943127, 0.212515891, 0.231857046, -0.28136012, -0.118400358, -0.217035949, -0.0496219955], [-0.3136262, 0.105287395, -0.037995059, -0.129876241, -0.142800108, -0.208000481, -0.0777831152, 0.144313246, -0.086798042, 0.255681425]]> : tensor<2x10xf32>) : !torch.vtensor<[2,10],f32>
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %4 = torch.aten.linear %arg0, %3, %2 : !torch.vtensor<[5,10],f32>, !torch.vtensor<[2,10],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[5,2],f32>
    %5 = torch.aten.linear %arg0, %1, %0 : !torch.vtensor<[5,10],f32>, !torch.vtensor<[2,10],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[5,2],f32>
    %6 = torch.aten.transpose.int %5, %int0, %int1 : !torch.vtensor<[5,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,5],f32>
    %7 = torch.aten.mm %4, %6 : !torch.vtensor<[5,2],f32>, !torch.vtensor<[2,5],f32> -> !torch.vtensor<[5,5],f32>
    return %7 : !torch.vtensor<[5,5],f32>
  }
}


sjarus commented Jun 28, 2022

The IR is really helpful - I can test it against what I have and see whether it legalizes correctly. I'll check today after morning meetings and post an update.


Shukla-Gaurav commented Jul 6, 2022

I am encountering the same issue with patch #862 for the static ResNet18 model. @sjarus
error: failed to legalize operation 'torch.constant.int'


sjarus commented Jul 6, 2022

Just pushed #1017 for this.

@AmosLewis

I encountered the same issue when lowering the Hugging Face GPT-2: https://gist.github.com/AmosLewis/9b929414d5677afda3528122f92bbc73 @sjarus
error: failed to legalize operation 'torch.constant.int'


sjarus commented Sep 22, 2022

torch.constant.int is a known missing conversion.

@silvasean

Was this fixed?


Svoch commented Oct 7, 2022

@silvasean - Yes. With the Torch to TOSA conversion of the Transpose Op merged, this issue can be marked as fixed.

Please note that the symptom, i.e. error: failed to legalize operation 'torch.constant.int', is common for most ops without a TOSA lowering in Torch-MLIR, since such ops usually leave dangling torch.constant.int operands (such as axis integers) behind after the Torch to TOSA conversion.


AmosLewis commented Dec 12, 2022

The torch.constant.int error just means that some aten op which uses this torch.constant.int as an operand hasn't been lowered successfully by your lowering code. You need to find the op that was not lowered successfully in the IR from the debug info, understand each line of your lowering code related to that op, and come up with a new plan.

The error will come up again and again for each op we try to lower, until we lower it successfully. As you can see in the comments, I hit this error many times when I started working on GPT. The error disappears once you lower your ops to TOSA (or other dialects: StableHLO, Linalg, TMTensor) correctly.

torch-mlir-opt -convert-torch-to-tosa /tmp/aten_as_tride.mlir -mlir-print-ir-after-all -mlir-disable-threading --debug  --mlir-print-ir-before-all

This is the command you might need to get more debug info; just replace /tmp/aten_as_tride.mlir with an op.mlir file you created manually. You can take my where.mlir file and the command in my comments as examples; here is the link: https://gist.github.com/AmosLewis/32847885f8b3ff27b7ef6564154fec59
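
As a rough sketch (not from the gist above), one way to produce such an op.mlir is to stop at the Torch backend contract with the Python API and write the module out, assuming torch_mlir.OutputType.TORCH is available in your build:

import torch
import torch_mlir

class TransposeOnly(torch.nn.Module):
    def forward(self, inputs):
        return inputs.transpose(0, 1)

# Lower only to the Torch backend contract (no TOSA conversion yet) and dump the IR
# to a file that can then be fed to torch-mlir-opt -convert-torch-to-tosa.
module = torch_mlir.compile(TransposeOnly(), torch.rand(5, 10), output_type=torch_mlir.OutputType.TORCH)
with open("/tmp/op.mlir", "w") as f:
    f.write(str(module))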

For those working on TOSA, here is the relationship between the two TOSA-related flags of torch-mlir-opt that you need to understand before diving into debugging:

-pass-pipeline='torch-backend-to-tosa-backend-pipeline' == -convert-torch-to-tosa + some additional cleanup/standard conversion passes (e.g. cleaning up the torch.constant.int ops whose aten users were successfully lowered to TOSA).

-pass-pipeline='torch-backend-to-tosa-backend-pipeline' runs the whole function at line 100, void TorchConversion::createTorchBackendToTosaBackendPipeline(OpPassManager &pm).

-convert-torch-to-tosa only runs line 102,

pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());

which calls the conversion function mlir::torch::createConvertTorchToTosaPass() that you extend for your ops in TorchToTosa.cpp, which in turn calls ConvertTorchToTosa::runOnOperation(). That function is where the matchAndRewrite patterns we add for lowering ops usually start.

The torch.constant.int ops should be cleaned up around line 113,

pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());

provided that line 102 (-convert-torch-to-tosa) did its job, as the comment at line 111 explains.

In each matchAndRewrite pattern, each aten op has a corresponding adaptor. The adaptor is the MLIR-internal view of the aten op. For example, with a where.mlir file containing torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>), calling atenop.getSelf().dump() on %arg0 gives you the Torch-level tensor type !torch.vtensor<[1,1,5,5],i1>, but calling adaptor.getSelf().dump() gives you the builtin type tensor<1x1x5x5xi1>.

You can find those useful op helper functions like getSelf() in your own build directory, in build/tools/torch-mlir/include/torch-mlir/dialect/torch/IR/TorchOps.h.inc; their implementations are in build/tools/torch-mlir/include/torch-mlir/dialect/torch/IR/TorchOps.cpp.inc. These are automatically generated by MLIR's TableGen from a .td file. The TableGen file lives in the corresponding directory of the torch-mlir source tree: https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td. In this .td file you will find the detailed types of each aten op, which is very useful when you come up with new lowering plans.

To work with the adaptor's types, which are MLIR-internal types, you will need the functions in external/llvm-project, e.g. https://github.com/llvm/llvm-project/blob/798fa4b415eea55c868ae42b874083cb9886991e/mlir/include/mlir/IR/Types.h and https://github.com/llvm/llvm-project/blob/798fa4b415eea55c868ae42b874083cb9886991e/mlir/include/mlir/IR/BuiltinTypes.h.

We have to go deep into this code, understand its design, and get familiar with it; otherwise we cannot debug anything successfully. This code is like the raw ingredients for a cook: C++ and Python are our cooking tools, and our job is to come up with a recipe (a lowering plan) and use the tools to cook (implement and debug) it from those ingredients.
