[cpu] failed to legalize unresolved materialization from ('i64') to 'index' that remained live after conversion #18899

pdhirajkumarprasad opened this issue Oct 25, 2024 · 18 comments
bug 🐞 Something isn't working



What happened?

For the given IR (IREE compiler version 20241024.1057 @ 9c5b57a):

module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>, %arg3: !torch.vtensor<[?,?,?],si64>, %arg4: !torch.vtensor<[4],si64>, %arg5: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?,128,384],i1>   attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %2 = torch.operator "onnx.Pad"(%arg1, %arg4, %1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?],si64> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %4 = torch.operator "onnx.Gather"(%arg5, %3) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %6 = torch.operator "onnx.Div"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Unsqueeze"(%6, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64> 
    %12 = torch.operator "onnx.Reshape"(%2, %11) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[?,?],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],si64> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %16 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %17 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %18 = torch.operator "onnx.Unsqueeze"(%12, %17) : (!torch.vtensor<[?,?,?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,1],si64> 
    %19 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %20 = torch.operator "onnx.Unsqueeze"(%arg3, %19) : (!torch.vtensor<[?,?,?],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,1,?],si64> 
    %21 = torch.operator "onnx.Cast"(%18) { = 9 : si64} : (!torch.vtensor<[?,?,?,1],si64>) -> !torch.vtensor<[?,?,?,1],i1> 
    %22 = torch.operator "onnx.Cast"(%20) { = 9 : si64} : (!torch.vtensor<[?,?,1,?],si64>) -> !torch.vtensor<[?,?,1,?],i1> 
    %23 = torch.operator "onnx.And"(%21, %22) : (!torch.vtensor<[?,?,?,1],i1>, !torch.vtensor<[?,?,1,?],i1>) -> !torch.vtensor<[?,?,?,?],i1> 
    %24 = torch.operator "onnx.Cast"(%23) { = 9 : si64} : (!torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?,?],i1> 
    %25 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1x1x128x384xi1>} : () -> !torch.vtensor<[1,1,128,384],i1> 
    %26 = torch.operator "onnx.And"(%24, %25) : (!torch.vtensor<[?,?,?,?],i1>, !torch.vtensor<[1,1,128,384],i1>) -> !torch.vtensor<[?,?,128,384],i1> 
    return %26 : !torch.vtensor<[?,?,128,384],i1>
  }
}

I get the following error:

<unknown>:0: error: failed to legalize unresolved materialization from ('i64') to 'index' that remained live after conversion
<unknown>:0: note: see current operation: %6 = "builtin.unrealized_conversion_cast"(%5) : (i64) -> index

Steps to reproduce your issue


iree-compile model.torch_onnx.mlir --iree-hal-target-backends=llvm-cpu -o comp.vmfb --iree-llvmcpu-target-cpu=host

Detail log:


What component(s) does this issue relate to?


Version information

No response

Additional context

No response


@pashu123 looks like the scf.forall resolution is not kicking in somewhere?


It's forming a dispatch with dag_root; looking into it.


Try top of tree. Quinn fixed a few builtins.


I am trying top of main. It's not coming from the builtins; it's being formed during dispatch region creation. Full dump: Line 47735


%27 = flow.dispatch.region[%23, %24, %25, %26] -> (tensor<?x?x128x384xi1>{%20, %21}) {
  %30 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
  ^bb0(%out: i1):
    linalg.yield %false : i1
  } -> tensor<?x?x128x384xi1>
  flow.return %30 : tensor<?x?x128x384xi1>
} count(%arg8: index, %arg9: index, %arg10: index, %arg11: index) -> (index, index, index) {
  %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg8, %arg9, %arg10, %arg11
  flow.return %x, %y, %z : index, index, index
}

It's forming here, in `wrapOpInDispatchRegion(RewriterBase &rewriter, Operation *op)`.


So, rather than creating count_from_dag_root, should it create a count_from_slice here? @MaheshRavishankar


Yes! I wonder why that is happening.


Max191 commented Oct 29, 2024


It is choosing count_from_dag_root here because the size of the tensor is data dependent. Here is the dump before FormDispatchRegions (see the op at the end of the dump):

// -----// IR Dump After FormScalarDispatchesPass (iree-dispatch-creation-form-scalar-dispatches) //----- //
util.func public @main_graph$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view, %arg5: !hal.buffer_view, %arg6: !hal.fence, %arg7: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
  %false = arith.constant false
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c0_i64 = arith.constant 0 : i64
  %cst = arith.constant dense<8> : tensor<i64>
  %cst_0 = arith.constant dense<0> : tensor<3xi64>
  %0 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg6) => %arg4 : !hal.buffer_view -> tensor<4xi64>
  %3 = hal.tensor.import wait(%arg6) => %arg5 : !hal.buffer_view -> tensor<2xi64>
  %extracted_slice = tensor.extract_slice %2[0] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded = tensor.expand_shape %extracted_slice [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted = tensor.extract %expanded[%c0] : tensor<1xi64>
  %extracted_slice_1 = tensor.extract_slice %2[1] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_2 = tensor.expand_shape %extracted_slice_1 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_3 = tensor.extract %expanded_2[%c0] : tensor<1xi64>
  %extracted_slice_4 = tensor.extract_slice %2[2] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_5 = tensor.expand_shape %extracted_slice_4 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_6 = tensor.extract %expanded_5[%c0] : tensor<1xi64>
  %extracted_slice_7 = tensor.extract_slice %2[3] [1] [1] : tensor<4xi64> to tensor<i64>
  %expanded_8 = tensor.expand_shape %extracted_slice_7 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_9 = tensor.extract %expanded_8[%c0] : tensor<1xi64>
  %4 = arith.index_cast %extracted : i64 to index
  %5 = arith.index_cast %extracted_6 : i64 to index
  %6 = arith.index_cast %extracted_3 : i64 to index
  %7 = arith.index_cast %extracted_9 : i64 to index
  %8 = tensor.empty() : tensor<i64>
  %9 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %extracted_20 = tensor.extract %3[%c1] : tensor<2xi64>
      %27 = arith.divsi %extracted_20, %c0_i64 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %inserted_slice = tensor.insert_slice %9 into %cst_0[1] [1] [1] : tensor<i64> into tensor<3xi64>
  %inserted_slice_10 = tensor.insert_slice %cst into %inserted_slice[2] [1] [1] : tensor<i64> into tensor<3xi64>
  %extracted_slice_11 = tensor.extract_slice %inserted_slice_10[0] [1] [1] : tensor<3xi64> to tensor<i64>
  %expanded_12 = tensor.expand_shape %extracted_slice_11 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_13 = tensor.extract %expanded_12[%c0] : tensor<1xi64>
  %10 = arith.cmpi eq, %extracted_13, %c0_i64 : i64
  %11 = arith.addi %4, %5 : index
  %12 = arith.addi %11, %0 : index
  %13 = arith.index_cast %12 : index to i64
  %14 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = %10, %13, %extracted_13 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_14 = tensor.extract %14[] : tensor<i64>
  %extracted_slice_15 = tensor.extract_slice %inserted_slice_10[1] [1] [1] : tensor<3xi64> to tensor<i64>
  %expanded_16 = tensor.expand_shape %extracted_slice_15 [] output_shape [1] : tensor<i64> into tensor<1xi64>
  %extracted_17 = tensor.extract %expanded_16[%c0] : tensor<1xi64>
  %15 = arith.cmpi eq, %extracted_17, %c0_i64 : i64
  %16 = arith.addi %6, %7 : index
  %17 = arith.addi %16, %1 : index
  %18 = arith.index_cast %17 : index to i64
  %19 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = %15, %18, %extracted_17 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_18 = tensor.extract %19[] : tensor<i64>
  %20 = arith.index_cast %extracted_14 : i64 to index
  %21 = arith.index_cast %extracted_18 : i64 to index

  // The init operand here is dependent on `%extracted_14` and `%extracted_18`
  %22 = tensor.empty(%20, %21) : tensor<?x?x128x384xi1>
  %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%22 : tensor<?x?x128x384xi1>) {
  ^bb0(%out: i1):
    linalg.yield %false : i1
  } -> tensor<?x?x128x384xi1>
  %24 = hal.tensor.barrier join(%23 : tensor<?x?x128x384xi1>) => %arg7 : !hal.fence
  %dim = tensor.dim %24, %c0 : tensor<?x?x128x384xi1>
  %dim_19 = tensor.dim %24, %c1 : tensor<?x?x128x384xi1>
  %25 = hal.tensor.export %24 : tensor<?x?x128x384xi1>{%dim, %dim_19} -> !hal.buffer_view
  util.return %25 : !hal.buffer_view
}

According to the comments in RegionOpUtils.cpp, it seems this has to stay as count_from_dag_root. I'm guessing the unrealized conversion casts are coming from:

  %20 = arith.index_cast %extracted_14 : i64 to index
  %21 = arith.index_cast %extracted_18 : i64 to index

I'll look at what is happening with these ops.


Max191 commented Oct 29, 2024

@MaheshRavishankar do we need to add a lowering for flow.dispatch.workgroup_count_from_dag_root?


No, that really can't be handled very well with scf.forall. But that's not the issue. Hold on.


@pdhirajkumarprasad this is again an issue where none of this has any real computation. This is just index computation for the whole code. We need to get better at triaging this (as a team). I know it fails in codegen, but it is not a codegen issue. Labeling it as codegen just increases latency.

Not pointing fingers. I didn't see the actual input at all. I just saw the error message too... Just pointing out that we spent two days looking somewhere else :)

@zjgarvey I think this is yours :D


Currently, the following 4 models are failing with the above error:



benvanik commented Nov 6, 2024

  %19 = flow.dispatch.region -> (tensor<i64>) {
    %26 = linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []} outs(%8 : tensor<i64>) {
    ^bb0(%out: i64):
      %27 = %15, %18, %extracted_17 : i64
      linalg.yield %27 : i64
    } -> tensor<i64>
    flow.return %26 : tensor<i64>
  } count() -> (index, index, index) {
    %c1_20 = arith.constant 1 : index
    flow.return %c1_20, %c1_20, %c1_20 : index, index, index
  }
  %extracted_18 = tensor.extract %19[] : tensor<i64>

This is really bad IR - we should be able to compile it, but this is really bad.
All of that could be %extracted_18 = %15, %18, %extracted_17 : i64 and run literally 10000x faster.
Good as a test of the compiler, but this should be P1 to fix at whatever level is best (I think linalg/detensorizing should have handled this).
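As a rough illustration of how little work that dispatch does, the sketch below reproduces the scalar logic in plain Python. It assumes the ONNX `Pad` pads layout `[x1_begin, x2_begin, x1_end, x2_end]` and ONNX `Reshape` with `allowzero = 0`, where a `0` in the target shape copies the input extent. The helper names are hypothetical; they only mirror what appears to be the compare-and-select behind `%10`, `%13`, and `%extracted_13` (and likewise `%15`, `%18`, `%extracted_17`) in the dump above.

```python
def padded_extent(dim: int, pad_begin: int, pad_end: int) -> int:
    """Hypothetical helper: extent of one axis after onnx.Pad (begin + input + end)."""
    return pad_begin + dim + pad_end


def reshape_extent(target: int, padded: int) -> int:
    """Hypothetical helper: one target-shape entry of onnx.Reshape with allowzero = 0.

    A target of 0 copies the corresponding (padded) input extent; any other
    value is used as-is. This is one i64 compare plus one select.
    """
    return padded if target == 0 else target


# Illustrative inputs (not from the issue): a 5x6 tensor with pads [1, 2, 3, 4].
pads = [1, 2, 3, 4]
in_shape = (5, 6)
dim0 = padded_extent(in_shape[0], pads[0], pads[2])  # 1 + 5 + 3 = 9
dim1 = padded_extent(in_shape[1], pads[1], pads[3])  # 2 + 6 + 4 = 12
print(reshape_extent(0, dim0), reshape_extent(4, dim1))  # prints "9 4"
```

Each output extent is a handful of integer ops on host-side scalars, which is the work the `tensor<i64>` dispatches in the dump wrap.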


zjgarvey commented Nov 6, 2024


We shouldn't be getting that kind of IR in the full model, since those ops should get scalarized there.

The problem is that we can't scalarize from a func.func input, so a bit more IR would be necessary for an accurate reproducer (specifically the producers for %arg4 and %arg5).


zjgarvey commented Nov 6, 2024



At least the model model--long-t5-tglobal-base-16384-book-summary--pszemraj seems to be hitting this i64-to-index conversion failure on an onnx.Einsum op, so it might be good to update the issue with a better reproducer.


Here is a new reduced IR from the model, without much modification, so that the data flow stays intact:

module {
  func.func @tf2onnx(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?,256],si64>, %arg2: !torch.vtensor<[?,256],si64>, %arg3: !torch.vtensor<[?,256],si64>) -> (!torch.vtensor<[?,256,768],f32> ) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %18 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32> 
    %483 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<30522x768xf32>} : () -> !torch.vtensor<[30522,768],f32> 
    %484 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<2x768xf32>} : () -> !torch.vtensor<[2,768],f32> 
    %485 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<512x768xf32>} : () -> !torch.vtensor<[512,768],f32> 
    %486 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %487 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %488 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_one_hot_depth_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %489 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_4_shape_0> : tensor<3xsi32>} : () -> !torch.vtensor<[3],si32> 
    %490 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %491 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_3_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %492 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_2_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %493 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %494 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %495 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_Reshape_shape_0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %499 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_bert_embeddings_ExpandDims__48> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64> 
    %500 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_2_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %501 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_1_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %502 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_Reshape_shape_1_0> : tensor<si32>} : () -> !torch.vtensor<[],si32> 
    %811 = torch.operator "onnx.Slice"(%485, %14, %13) : (!torch.vtensor<[512,768],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[256,768],f32> 
    %812 = torch.operator "onnx.Cast"(%489) { = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %813 = torch.operator "onnx.Reshape"(%811, %812) : (!torch.vtensor<[256,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[1,256,768],f32> 
    %814 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %815 = torch.operator "onnx.Unsqueeze"(%490, %814) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %816 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %817 = torch.operator "onnx.Unsqueeze"(%491, %816) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %818 = torch.operator "onnx.Cast"(%492) { = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64> 
    %819 = torch.operator "onnx.Reshape"(%arg1, %818) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> 
    %820 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %821 = torch.operator "onnx.Unsqueeze"(%493, %820) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %822 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %823 = torch.operator "onnx.Unsqueeze"(%494, %822) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %824 = torch.operator "onnx.Cast"(%495) { = 7 : si64} : (!torch.vtensor<[1],si32>) -> !torch.vtensor<[1],si64> 
    %825 = torch.operator "onnx.Reshape"(%arg3, %499) : (!torch.vtensor<[?,256],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,1],si64> 
    %826 = torch.operator "onnx.Shape"(%825) : (!torch.vtensor<[?,256,1],si64>) -> !torch.vtensor<[3],si64> 
    %827 = torch.operator "onnx.Cast"(%826) { = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32> 
    %828 = torch.operator "onnx.Slice"(%827, %9, %8, %7) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %829 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %830 = torch.operator "onnx.Squeeze"(%828, %829) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32> 
    %831 = torch.operator "onnx.Cast"(%830) { = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32> 
    %832 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %833 = torch.operator "onnx.Unsqueeze"(%831, %832) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %834 = torch.operator "onnx.Concat"(%833, %823, %821) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32> 
    %835 = torch.operator "onnx.Cast"(%834) { = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %836 = torch.operator "onnx.Reshape"(%825, %824) : (!torch.vtensor<[?,256,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> 
    %837 = torch.operator "onnx.Gather"(%483, %836) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[30522,768],f32>, !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,768],f32> 
    %838 = torch.operator "onnx.Reshape"(%837, %835) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32> 
    %839 = torch.operator "onnx.Shape"(%838) : (!torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[3],si64> 
    %840 = torch.operator "onnx.Cast"(%839) { = 1 : si64} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[3],f32> 
    %841 = torch.operator "onnx.Slice"(%840, %6, %5, %4) : (!torch.vtensor<[3],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %842 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %843 = torch.operator "onnx.Squeeze"(%841, %842) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[],f32> 
    %844 = torch.operator "onnx.Cast"(%843) { = 6 : si64} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],si32> 
    %845 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %846 = torch.operator "onnx.Unsqueeze"(%844, %845) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %847 = torch.operator "onnx.Concat"(%846, %817, %815) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[3],si32> 
    %848 = torch.operator "onnx.Cast"(%847) { = 7 : si64} : (!torch.vtensor<[3],si32>) -> !torch.vtensor<[3],si64> 
    %849 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %850 = torch.operator "onnx.Unsqueeze"(%487, %849) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %851 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %852 = torch.operator "onnx.Unsqueeze"(%486, %851) : (!torch.vtensor<[],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],f32> 
    %853 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %854 = torch.operator "onnx.Unsqueeze"(%488, %853) : (!torch.vtensor<[],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %862 = torch.operator "onnx.Concat"(%850, %852) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[2],f32> 
    %863 = torch.operator "onnx.OneHot"(%819, %854, %862) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[?],si64>, !torch.vtensor<[1],si32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[?,?],f32> 
    %864 = torch.operator "onnx.MatMul"(%863, %484) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[2,768],f32>) -> !torch.vtensor<[?,768],f32> 
    %865 = torch.operator "onnx.Reshape"(%864, %848) : (!torch.vtensor<[?,768],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,256,768],f32> 
    %866 = torch.operator "onnx.Add"(%838, %865) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[?,256,768],f32>) -> !torch.vtensor<[?,256,768],f32> 
    %867 = torch.operator "onnx.Add"(%866, %813) : (!torch.vtensor<[?,256,768],f32>, !torch.vtensor<[1,256,768],f32>) -> !torch.vtensor<[?,256,768],f32> 
    return %867 : !torch.vtensor<[?,256,768],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _bert_embeddings_one_hot_on_value_0: "0x080000000000803F",
      _bert_embeddings_one_hot_off_value_0: "0x0800000000000000",
      _bert_embeddings_one_hot_depth_0: "0x0800000002000000",
      _bert_embeddings_Reshape_4_shape_0: "0x08000000010000000001000000030000",
      _bert_embeddings_Reshape_3_shape_2_0: "0x0800000000030000",
      _bert_embeddings_Reshape_3_shape_1_0: "0x0800000000010000",
      _bert_embeddings_Reshape_2_shape_0: "0x08000000FFFFFFFF",
      _bert_embeddings_Reshape_1_shape_2_0: "0x0800000000030000",
      _bert_embeddings_Reshape_1_shape_1_0: "0x0800000000010000",
      _bert_embeddings_Reshape_shape_0: "0x08000000FFFFFFFF",
      _bert_embeddings_LayerNorm_batchnorm_add_y_0: "0x08000000CCBC8C2B",
      _bert_embeddings_ExpandDims__48: "0x08000000FFFFFFFFFFFFFFFF00010000000000000100000000000000",
      _Reshape_1_shape_2_0: "0x0800000002000000",
      _Reshape_1_shape_1_0: "0x0800000000010000",
      _Reshape_shape_1_0: "0x0800000000030000",
      _: "0x080000000000803F"
    }
  }
#-}


zjgarvey commented Nov 6, 2024

Thanks @pdhirajkumarprasad, let me verify the scalarization is working properly here.


zjgarvey commented Nov 6, 2024

Nice, it looks like there are just a few casts interrupting the scalarization. I can definitely fix this quickly.
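One plausible example of such a cast chain in the reduced IR is `%826` to `%831`: `onnx.Shape`, a cast to f32 (`= 1`), a `Slice` taking element 0, a `Squeeze`, and a cast to si32 (`= 6`). The Python sketch below, with hypothetical helper names, shows that the whole chain folds to the leading dimension as a plain integer, emulating the f32 step with `struct`. It assumes round-to-nearest f32 conversion and a truncating float-to-int cast; since f32 represents integers exactly only up to 2**24, the fold is exact only for extents in that range.

```python
import struct


def to_f32(value: float) -> float:
    """Hypothetical helper: round a Python float to IEEE-754 single precision."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def fold_leading_extent(shape: tuple[int, ...]) -> int:
    """Hypothetical helper: scalar equivalent of
    Shape -> Cast(f32) -> Slice[0:1] -> Squeeze -> Cast(si32)."""
    as_f32 = [to_f32(float(d)) for d in shape]  # cast each extent to f32
    sliced = as_f32[0:1]                         # slice starts=0, ends=1, axes=0
    scalar = sliced[0]                           # squeeze axis 0
    return int(scalar)                           # cast to integer (truncates toward zero)


print(fold_leading_extent((7, 256, 1)))  # prints 7
```

Once the casts are folded like this, the value stays a scalar and no longer needs to round-trip through a tensor dispatch.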

zjgarvey added a commit to llvm/torch-mlir that referenced this issue Nov 12, 2024
…ainderTensorOp` (#3861)

1. adds a lowering for `` and `` to arith.
2. adds a scalarization pattern for `aten.neg` and
`aten.remainder.Tensor` ops.
3. improves folding of ``
4. adds a scalarization pattern for `` which relies on
scalar cast ops and basic C++ casting between `double` and `int64_t`.
5. improves rank-0 case handling for `FoldAtenSplatPattern`
6. removes a bug with `` decomposition incorrectly
generating a constant size int from a dynamic shape.
7. simplifies the dim list for `` ops generated from
the `aten.view` canonicalization in scalarize shapes.

All of these changes were necessary to unblock

6 participants