[cpu] failed to legalize unresolved materialization from ('i64') to 'index' that remained live after conversion #18899

Open
pdhirajkumarprasad opened this issue Oct 25, 2024 · 18 comments
Labels: bug 🐞 (Something isn't working)

Comments

@pdhirajkumarprasad

What happened?

For the given IR (IREE compiler version 20241024.1057 @ 9c5b57a):

module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1>   attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %2 = torch.operator "onnx.Pad"(%arg1, %arg4, %1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?],si64> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %4 = torch.operator "onnx.Gather"(%arg5, %3) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %6 = torch.operator "onnx.Div"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Unsqueeze"(%6, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64> 
    %12 = torch.operator "onnx.Reshape"(%2, %11) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[?,?],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],si64> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %16 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %17 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %18 = torch.operator "onnx.Unsqueeze"(%12, %17) : (!torch.vtensor<[?,?,?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,1],si64> 
    %19 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %20 = torch.operator "onnx.Unsqueeze"(%arg3, %19) : (!torch.vtensor<[?,?,?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,1,?],si64> 
    %21 = torch.operator "onnx.Cast"(%18) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,1],si64>) -> !torch.vtensor<[?,?,?,1],i1> 
    %22 = torch.operator "onnx.Cast"(%20) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,1,?],si64>) -> !torch.vtensor<[?,?,1,?],i1> 
    %23 = torch.operator "onnx.And"(%21, %22) : (!torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1>) -> !torch.vtensor<[?,?,?,?],i1> 
    %24 = torch.operator "onnx.Cast"(%23) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?,?],i1> 
    %25 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1x1x128x384xi1>} : () -> !torch.vtensor<[1,1,128,384],i1> 
    %26 = torch.operator "onnx.And"(%24, %25) : (!torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1>) -> !torch.vtensor<[?,?,128,384],i1> 
    return %26 : !torch.vtensor<[?,?,128,384],i1>
  }
}

I get the following error:

<unknown>:0: error: failed to legalize unresolved materialization from ('i64') to 'index' that remained live after conversion
<unknown>:0: note: see current operation: %6 = "builtin.unrealized_conversion_cast"(%5) : (i64) -> index

Steps to reproduce your issue

Command

iree-compile model.torch_onnx.mlir --iree-hal-target-backends=llvm-cpu -o comp.vmfb --iree-llvmcpu-target-cpu=host

Detail log:

dump.log

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

@MaheshRavishankar
Contributor

@pashu123 looks like the scf.forall resolution is not kicking in somewhere?

@pashu123
Contributor

> @pashu123 looks like the scf.forall resolution is not kicking in somewhere?

It's forming a dispatch with dag_root: https://gist.github.com/pashu123/12dd8d3771a1a5cfd99a52ee5b001a98#file-module_main_graph-async_dispatch_3-mlir-L5. Looking into it.

@MaheshRavishankar
Contributor

Try top of tree. Quinn fixed a few builtins.

@pashu123
Contributor

> Try top of tree. Quinn fixed a few builtins.

I am trying top of main. It's not coming from the builtins; it forms during dispatch region creation. Full dump: https://gist.github.com/pashu123/fb9a9d29b9f199d6f10bfb3c2d55ed49 (line 47735).

@pashu123
Contributor

%27 = flow.dispatch.region[%23, %24, %25, %26] -> (tensor<?x?x128x384xi1>{%20, %21}) {
  %30 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
  ^bb0(%out: i1):
    linalg.yield %false : i1
  } -> tensor<?x?x128x384xi1>
  flow.return %30 : tensor<?x?x128x384xi1>
} count(%arg8: index, %arg9: index, %arg10: index, %arg11: index) -> (index, index, index) {
  %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg8, %arg9, %arg10, %arg11
  flow.return %x, %y, %z : index, index, index
}

It's forming here, in wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op).

@pashu123
Contributor

So, rather than creating count_from_dag_root, should it create a count_from_slice here? @MaheshRavishankar

@MaheshRavishankar
Contributor

Yes! I wonder why that is happening.

@Max191
Contributor

Max191 commented Oct 29, 2024

> Yes! I wonder why that is happening.

It is choosing count_from_dag_root here because the size of the tensor is data dependent. Here is the dump before FormDispatchRegions (see the op at the end of the dump):

// -----// IR Dump After FormScalarDispatchesPass (iree-dispatch-creation-form-scalar-dispatches) //----- //
util.func public @main_graph$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.buffer_view, %arg6: !hal.fence, %arg7: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
  %false = arith.constant false
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c0_i64 = arith.constant 0 : i64
  %cst = arith.constant dense<8> : tensor<i64>
  %cst_0 = arith.constant dense<0> : tensor<3xi64>
  %0 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg6) => %arg4 : !hal.buffer_view -> tensor<4xi64>
  %3 = hal.tensor.import wait(%arg6) => %arg5 : !hal.buffer_view -> tensor<2xi64>
  %extracted_slice = tensor.extract_slice %2[0] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded = tensor.expand_shape %extracted_slice [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted = tensor.extract %expanded[%c0] : tensor<1xi64>
  %extracted_slice_1 = tensor.extract_slice %2[1] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_2 = tensor.expand_shape %extracted_slice_1 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_3 = tensor.extract %expanded_2[%c0] : tensor<1xi64>
  %extracted_slice_4 = tensor.extract_slice %2[2] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_5 = tensor.expand_shape %extracted_slice_4 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_6 = tensor.extract %expanded_5[%c0] : tensor<1xi64>
  %extracted_slice_7 = tensor.extract_slice %2[3] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_8 = tensor.expand_shape %extracted_slice_7 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_9 = tensor.extract %expanded_8[%c0] : tensor<1xi64>
  %4 = arith.index_cast %extracted : i64 to index
  %5 = arith.index_cast %extracted_6 : i64 to index
  %6 = arith.index_cast %extracted_3 : i64 to index
  %7 = arith.index_cast %extracted_9 : i64 to index
  %8 = tensor.empty() : tensor<i64>
  %9 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %extracted_20 = tensor.extract %3[%c1] : tensor<2xi64>
      %27 = arith.divsi %extracted_20, %c0_i64 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %inserted_slice = tensor.insert_slice %9 into %cst_0[1] [1] [1] : tensor<i64> into tensor<3xi64>
  %inserted_slice_10 = tensor.insert_slice %cst into %inserted_slice[2] [1] [1] : tensor<i64> into tensor<3xi64>
  %extracted_slice_11 = tensor.extract_slice %inserted_slice_10[0] [1] [1] : tensor<3xi64> to tensor<i64>
  %expanded_12 = tensor.expand_shape %extracted_slice_11 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_13 = tensor.extract %expanded_12[%c0] : tensor<1xi64>
  %10 = arith.cmpi eq, %extracted_13, %c0_i64 : i64
  %11 = arith.addi %4, %5 : index
  %12 = arith.addi %11, %0 : index
  %13 = arith.index_cast %12 : index to i64
  %14 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = arith.select %10, %13, %extracted_13 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_14 = tensor.extract %14[] : tensor<i64>
  %extracted_slice_15 = tensor.extract_slice %inserted_slice_10[1] [1] [1] : tensor<3xi64> to tensor<i64>
  %expanded_16 = tensor.expand_shape %extracted_slice_15 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_17 = tensor.extract %expanded_16[%c0] : tensor<1xi64>
  %15 = arith.cmpi eq, %extracted_17, %c0_i64 : i64
  %16 = arith.addi %6, %7 : index
  %17 = arith.addi %16, %1 : index
  %18 = arith.index_cast %17 : index to i64
  %19 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = arith.select %15, %18, %extracted_17 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_18 = tensor.extract %19[] : tensor<i64>
  %20 = arith.index_cast %extracted_14 : i64 to index
  %21 = arith.index_cast %extracted_18 : i64 to index

  // The init operand here is dependent on `%extracted_14` and `%extracted_18`
  %22 = tensor.empty(%20, %21) : tensor<?x?x128x384xi1>
  %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
  ^bb0(%out: i1):
    linalg.yield %false : i1
  } -> tensor<?x?x128x384xi1>
  %24 = hal.tensor.barrier join(%23 : tensor<?x?x128x384xi1>) => %arg7 : !hal.fence
  %dim = tensor.dim %24, %c0 : tensor<?x?x128x384xi1>
  %dim_19 = tensor.dim %24, %c1 : tensor<?x?x128x384xi1>
  %25 = hal.tensor.export %24 : tensor<?x?x128x384xi1>{%dim, %dim_19} -> !hal.buffer_view
  util.return %25 : !hal.buffer_view
}

According to the comments in RegionOpUtils.cpp, it seems this has to stay as count_from_dag_root. I'm guessing the unrealized conversion casts are coming from these two ops:

  %20 = arith.index_cast %extracted_14 : i64 to index
  %21 = arith.index_cast %extracted_18 : i64 to index

I'll look at what is happening with these ops.
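To make the data dependence concrete: the index arithmetic in the dump above can be read as the output-shape computation of onnx.Pad followed by onnx.Reshape with allowzero = 0. The following is a minimal Python sketch under that reading. The function names are hypothetical and the numbers are made up for illustration; it is not how the compiler computes anything.

```python
def padded_dim(dim: int, pad_begin: int, pad_end: int) -> int:
    """Size of one dimension after constant padding (onnx.Pad)."""
    return pad_begin + pad_end + dim


def reshape_dim(requested: int, input_dim: int) -> int:
    """onnx.Reshape with allowzero = 0: a requested size of 0 copies the input dimension."""
    return input_dim if requested == 0 else requested


# Mirrors the two arith.addi ops (%11, %12), the arith.cmpi eq against 0 (%10),
# and the arith.select inside the scalar dispatch (%14).
dim0 = reshape_dim(0, padded_dim(dim=5, pad_begin=1, pad_end=2))
print(dim0)  # 8
```

Because the pad amounts are read from a runtime input (%arg4 in the original IR), the resulting dimension is only known at run time, which is why the tensor.empty feeding the final linalg.generic has data-dependent sizes.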

@Max191
Contributor

Max191 commented Oct 29, 2024

@MaheshRavishankar do we need to add a lowering for flow.dispatch.workgroup_count_from_dag_root?

@MaheshRavishankar
Contributor

> @MaheshRavishankar do we need to add a lowering for flow.dispatch.workgroup_count_from_dag_root?

No, that really can't be handled very well with scf.forall. But that's not the issue. Hold on.

@MaheshRavishankar
Contributor

@pdhirajkumarprasad this is again an issue where the IR has no real computation; the whole thing is just index computation. We need to get better at triaging this (as a team). I know it fails in codegen, but it is not a codegen issue. Labeling it as codegen just increases latency.

Not pointing fingers. I didn't look at the actual input at all; I just saw the error message too. Just pointing out that we spent two days looking somewhere else :)

@zjgarvey I think this is yours :D

@pdhirajkumarprasad
Author

Currently the following 4 models fail with the above error:

model--long-t5-tglobal-base-16384-book-summary--pszemraj
model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj
model--long-t5-tglobal-base-16384-booksum-V12--pszemraj
migraphx_bert__bertsquad-12

@benvanik
Collaborator

benvanik commented Nov 6, 2024

  %19 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = arith.select %15, %18, %extracted_17 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_18 = tensor.extract %19[] : tensor<i64>

This is really bad IR - we should be able to compile it, but this is really bad.
All of that could be %extracted_18 = arith.select %15, %18, %extracted_17 : i64 and run literally 10000x faster.
Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).

@zjgarvey
Contributor

zjgarvey commented Nov 6, 2024

> This is really bad IR - we should be able to compile it, but this is really bad. All of that could be %extracted_18 = arith.select %15, %18, %extracted_17 : i64 and run literally 10000x faster. Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).

We shouldn't be getting that kind of IR in the full model, since those arith.select ops should get scalarized there.

The problem is that we can't scalarize from a func.func input, so a bit more IR would be necessary for an accurate reproducer (specifically the producers for %arg4 and %arg5).

@zjgarvey
Contributor

zjgarvey commented Nov 6, 2024

> Currently we have following 4 models failing with above error
>
> model--long-t5-tglobal-base-16384-book-summary--pszemraj
> model--long-t5-tglobal-base-16384-booksum-V11-big_patent-V2--pszemraj
> model--long-t5-tglobal-base-16384-booksum-V12--pszemraj
> migraphx_bert__bertsquad-12

At least model--long-t5-tglobal-base-16384-book-summary--pszemraj seems to fail this i64-to-index conversion on an onnx.Einsum op, so it might be good to update the issue with a better reproducer.

@pdhirajkumarprasad
Author

Here is a new reduced IR from the model, without much modification, so that the data flow stays intact:

module {
  func.func @tf2onnx(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?,256],si64>, %arg2: !torch.vtensor<[?,256],si64>, %arg3: !torch.vtensor<[?,256],si64>) -> (!torch.vtensor<[?,256,768],f32> ) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %18 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32> 
    %483 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<30522x768xf32>} : () -> !torch.vtensor<[30522,768],f32> 
    %484 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32> 
    %485 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<512x768xf32>} : () -> !torch.vtensor<[512,768],f32> 
    %486 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %487 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %488 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_one_hot_depth_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %489 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_4_shape_0> : tensor<3xsi32>} : () -> !torch.vtensor<[3],si32> 
    %490 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %491 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %492 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_2_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %493 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %494 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %495 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %499 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_ExpandDims__48> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64> 
    %500 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %501 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %502 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %811 = torch.operator "onnx.Slice"(%485, %14, %13) : (!torch.vtensor<[512,768],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[256,768],f32> 
    %812 = torch.operator "onnx.Cast"(%489) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %813 = torch.operator "onnx.Reshape"(%811, %812) : (!torch.vtensor<[256,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,256,768],f32> 
    %814 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %815 = torch.operator "onnx.Unsqueeze"(%490, %814) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %816 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %817 = torch.operator "onnx.Unsqueeze"(%491, %816) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %818 = torch.operator "onnx.Cast"(%492) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64> 
    %819 = torch.operator "onnx.Reshape"(%arg1, %818) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> 
    %820 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %821 = torch.operator "onnx.Unsqueeze"(%493, %820) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %822 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %823 = torch.operator "onnx.Unsqueeze"(%494, %822) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %824 = torch.operator "onnx.Cast"(%495) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64> 
    %825 = torch.operator "onnx.Reshape"(%arg3, %499) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,1],si64> 
    %826 = torch.operator "onnx.Shape"(%825) : (!torch.vtensor<[?,256,1],si64>) -> !torch.vtensor<[3],si64> 
    %827 = torch.operator "onnx.Cast"(%826) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32> 
    %828 = torch.operator "onnx.Slice"(%827, %9, %8, %7) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %829 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %830 = torch.operator "onnx.Squeeze"(%828, %829) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32> 
    %831 = torch.operator "onnx.Cast"(%830) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32> 
    %832 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %833 = torch.operator "onnx.Unsqueeze"(%831, %832) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %834 = torch.operator "onnx.Concat"(%833, %823, %821) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32> 
    %835 = torch.operator "onnx.Cast"(%834) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %836 = torch.operator "onnx.Reshape"(%825, %824) : (!torch.vtensor<[?,256,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> 
    %837 = torch.operator "onnx.Gather"(%483, %836) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[30522,768],f32>, !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,768],f32> 
    %838 = torch.operator "onnx.Reshape"(%837, %835) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32> 
    %839 = torch.operator "onnx.Shape"(%838) : (!torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[3],si64> 
    %840 = torch.operator "onnx.Cast"(%839) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32> 
    %841 = torch.operator "onnx.Slice"(%840, %6, %5, %4) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %842 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %843 = torch.operator "onnx.Squeeze"(%841, %842) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32> 
    %844 = torch.operator "onnx.Cast"(%843) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32> 
    %845 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %846 = torch.operator "onnx.Unsqueeze"(%844, %845) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %847 = torch.operator "onnx.Concat"(%846, %817, %815) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32> 
    %848 = torch.operator "onnx.Cast"(%847) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %849 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %850 = torch.operator "onnx.Unsqueeze"(%487, %849) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %851 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %852 = torch.operator "onnx.Unsqueeze"(%486, %851) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %853 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %854 = torch.operator "onnx.Unsqueeze"(%488, %853) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %862 = torch.operator "onnx.Concat"(%850, %852) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[2],f32> 
    %863 = torch.operator "onnx.OneHot"(%819, %854, %862) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[?],si64>, !torch.vtensor<[1],si32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[?,?],f32> 
    %864 = torch.operator "onnx.MatMul"(%863, %484) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[2,768],f32>) -> !torch.vtensor<[?,768],f32> 
    %865 = torch.operator "onnx.Reshape"(%864, %848) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32> 
    %866 = torch.operator "onnx.Add"(%838, %865) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[?,256,768],f32> 
    %867 = torch.operator "onnx.Add"(%866, %813) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[1,256,768],f32>) -> !torch.vtensor<[?,256,768],f32> 
    return %867: !torch.vtensor<[?,256,768],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _bert_embeddings_one_hot_on_value_0: "0x080000000000803F",
      _bert_embeddings_one_hot_off_value_0: "0x0800000000000000",
      _bert_embeddings_one_hot_depth_0: "0x0800000002000000",
      _bert_embeddings_Reshape_4_shape_0: "0x08000000010000000001000000030000",
      _bert_embeddings_Reshape_3_shape_2_0: "0x0800000000030000",
      _bert_embeddings_Reshape_3_shape_1_0: "0x0800000000010000",
      _bert_embeddings_Reshape_2_shape_0: "0x08000000FFFFFFFF",
      _bert_embeddings_Reshape_1_shape_2_0: "0x0800000000030000",
      _bert_embeddings_Reshape_1_shape_1_0: "0x0800000000010000",
      _bert_embeddings_Reshape_shape_0: "0x08000000FFFFFFFF",
      _bert_embeddings_LayerNorm_batchnorm_add_y_0: "0x08000000CCBC8C2B",
      _bert_embeddings_ExpandDims__48: "0x08000000FFFFFFFFFFFFFFFF00010000000000000100000000000000",
      _Reshape_1_shape_2_0: "0x0800000002000000",
      _Reshape_1_shape_1_0: "0x0800000000010000",
      _Reshape_shape_1_0: "0x0800000000030000",
      _: "0x080000000000803F"
    }
  }
#-}
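For readers who want to see which shape constants the reshapes above actually receive, the dense_resource blobs can be decoded by hand. The sketch below is a rough illustration only: it assumes each hex string starts with a 4-byte little-endian alignment word followed by raw little-endian element data, which is consistent with how these particular strings decode. The helper name decode_blob is hypothetical.

```python
import struct


def decode_blob(hex_blob: str, fmt: str) -> tuple:
    """Decode a dense_resource hex string into a tuple of elements.

    Assumption: 4-byte little-endian alignment word, then raw little-endian
    element data. `fmt` is a struct code: "i" (si32), "q" (si64), "f" (f32).
    """
    digits = hex_blob[2:] if hex_blob.startswith("0x") else hex_blob
    raw = bytes.fromhex(digits)
    payload = raw[4:]  # skip the alignment word
    count = len(payload) // struct.calcsize("<" + fmt)
    return struct.unpack(f"<{count}{fmt}", payload)


# _bert_embeddings_Reshape_4_shape_0 : tensor<3xsi32>
print(decode_blob("0x08000000010000000001000000030000", "i"))  # (1, 256, 768)
# _bert_embeddings_ExpandDims__48 : tensor<3xsi64>
print(decode_blob("0x08000000FFFFFFFFFFFFFFFF00010000000000000100000000000000", "q"))  # (-1, 256, 1)
```

The decoded values line up with the result types in the IR: [1,256,768] for %813 and [?,256,1] for %825.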

@zjgarvey
Contributor

zjgarvey commented Nov 6, 2024

Thanks @pdhirajkumarprasad, let me verify the scalarization is working properly here.

@zjgarvey
Contributor

zjgarvey commented Nov 6, 2024

Nice, it looks like there are just a few casts interrupting the scalarization. I can definitely fix this quickly.
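As an illustration of the kind of cast chain involved, the reduced IR takes a dimension from onnx.Shape, casts it to float (to = 1), slices and squeezes it, casts it to int32 (to = 6), and finally back to int64 (to = 7) before using it as a reshape size. A minimal Python sketch of the same round trip on a single scalar follows. The function name is hypothetical, and Python's float is a double rather than f32, so this only models the small, exactly representable sizes that shapes normally have.

```python
def shape_dim_through_casts(dim: int) -> int:
    """Follow one shape dimension through the float/int casts seen in the reduced IR."""
    as_float = float(dim)   # onnx.Cast, to = 1 (float)
    as_i32 = int(as_float)  # onnx.Cast, to = 6 (int32)
    as_i64 = int(as_i32)    # onnx.Cast, to = 7 (int64)
    return as_i64


print(shape_dim_through_casts(4))  # 4
```

Since the value comes out unchanged, the whole chain can in principle be evaluated on scalars instead of on small tensors, which is the point of scalarizing through the casts.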

zjgarvey added a commit to llvm/torch-mlir that referenced this issue Nov 12, 2024
…ainderTensorOp` (#3861)

1. adds a lowering for `aten.neg.int` and `aten.remainder.int` to arith.
2. adds a scalarization pattern for `aten.neg` and
`aten.remainder.Tensor` ops.
3. improves folding of `aten.mul.int`
4. adds a scalarization pattern for `aten.to.dtype` which relies on
scalar cast ops and basic C++ casting between `double` and `int64_t`.
5. improves rank-0 case handling for `FoldAtenSplatPattern`
6. removes a bug with `aten.unflatten.int` decomposition incorrectly
generating a constant size int from a dynamic shape.
7. simplifies the dim list for `aten.unflatten.int` ops generated from
the `aten.view` canonicalization in scalarize shapes.

All of these changes were necessary to unblock
<iree-org/iree#18899>.