
Generate MLIR with shape information via LTC frontend #742

Conversation

henrytwo
Member

@henrytwo henrytwo commented Apr 7, 2022

Previously, the intermediate JIT graph was generated without shape information, which caused generic !torch.tensor types to be scattered throughout the final MLIR. This PR improves the lowering by annotating each jit::Value of tensor type with the corresponding shape from its originating lazy::Node. Additionally, the MLIR parameter and return values are now generated with their respective shapes.
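For illustration, the shape propagation boils down to refining each JIT value's type with the shape recorded on its lazy node. Below is a minimal sketch, not the exact PR diff; the variable names and the particular TensorType::create overload are assumptions:

// Sketch only; assumes torch/csrc/lazy/core/shape.h and
// torch/csrc/jit/ir/ir.h are included.
const torch::lazy::Shape &shape = lazy_node->shape(output_index);
torch::jit::Value *value = jit_node->output(output_index);
// Refine the generic tensor type with the dtype and sizes from the lazy
// shape, leaving device, strides, and requires_grad unspecified.
value->setType(c10::TensorType::create(
    /*scalar_type=*/shape.scalar_type(),
    /*device=*/c10::nullopt,
    /*sizes=*/c10::VaryingShape<int64_t>(shape.sizes()),
    /*strides=*/c10::VaryingShape<int64_t>(),
    /*requires_grad=*/c10::nullopt));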

Example output for an FC MNIST model:

func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.vtensor<[10,5],f32>, %arg7: !torch.vtensor<[10],f32>, %arg8: !torch.vtensor<*,f32>, %arg9: !torch.float) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<*,f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>) {
  %0 = torch.aten.detach %arg6 :  !torch.vtensor<[10,5],f32> ->  !torch.vtensor<[10,5],f32>
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %1 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %int1_0 = torch.constant.int 1
  %int0_1 = torch.constant.int 0
  %2 = torch.prim.ListConstruct %int1_0, %int0_1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.permute %0, %2 :  !torch.vtensor<[10,5],f32>, !torch.list<int> ->  !torch.vtensor<[5,10],f32>
  %4 = torch.aten.detach %arg7 :  !torch.vtensor<[10],f32> ->  !torch.vtensor<[10],f32>
  %5 = torch.aten.addmm %4, %arg0, %3, %arg5, %arg4 :  !torch.vtensor<[10],f32>, !torch.vtensor<[1,5],f32>, !torch.vtensor<[5,10],f32>, !torch.int, !torch.int ->  !torch.vtensor<[1,10],f32>
  %6 = torch.aten.relu %5 :  !torch.vtensor<[1,10],f32> ->  !torch.vtensor<[1,10],f32>
  %int1_2 = torch.constant.int 1
  %false = torch.constant.bool false
  %7 = torch.aten._log_softmax %6, %int1_2, %false :  !torch.vtensor<[1,10],f32>, !torch.int, !torch.bool ->  !torch.vtensor<[1,10],f32>
  %none = torch.constant.none
  %int1_3 = torch.constant.int 1
  %int-100 = torch.constant.int -100
  %output, %total_weight = torch.aten.nll_loss_forward %7, %arg1, %none, %int1_3, %int-100 :  !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int ->  !torch.vtensor<*,f32>, !torch.vtensor<*,f32>
  %8 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tuple<vtensor, vtensor>
  %none_4 = torch.constant.none
  %int1_5 = torch.constant.int 1
  %int-100_6 = torch.constant.int -100
  %9 = torch.aten.nll_loss_backward %arg8, %7, %arg1, %none_4, %int1_5, %int-100_6, %total_weight :  !torch.vtensor<*,f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<*,f32> ->  !torch.vtensor<[1,10],f32>
  %int1_7 = torch.constant.int 1
  %int6 = torch.constant.int 6
  %10 = torch.aten._log_softmax_backward_data %9, %7, %int1_7, %int6 :  !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int, !torch.int ->  !torch.vtensor<[1,10],f32>
  %11 = torch.aten.threshold_backward %10, %6, %arg3 :  !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int ->  !torch.vtensor<[1,10],f32>
  %int1_8 = torch.constant.int 1
  %int0_9 = torch.constant.int 0
  %12 = torch.prim.ListConstruct %int1_8, %int0_9 : (!torch.int, !torch.int) -> !torch.list<int>
  %int1_10 = torch.constant.int 1
  %int0_11 = torch.constant.int 0
  %13 = torch.prim.ListConstruct %int1_10, %int0_11 : (!torch.int, !torch.int) -> !torch.list<int>
  %14 = torch.aten.permute %arg0, %13 :  !torch.vtensor<[1,5],f32>, !torch.list<int> ->  !torch.vtensor<[5,1],f32>
  %15 = torch.aten.mm %14, %11 :  !torch.vtensor<[5,1],f32>, !torch.vtensor<[1,10],f32> ->  !torch.vtensor<[5,10],f32>
  %int1_12 = torch.constant.int 1
  %int0_13 = torch.constant.int 0
  %16 = torch.prim.ListConstruct %int1_12, %int0_13 : (!torch.int, !torch.int) -> !torch.list<int>
  %int1_14 = torch.constant.int 1
  %int0_15 = torch.constant.int 0
  %17 = torch.prim.ListConstruct %int1_14, %int0_15 : (!torch.int, !torch.int) -> !torch.list<int>
  %18 = torch.aten.permute %15, %17 :  !torch.vtensor<[5,10],f32>, !torch.list<int> ->  !torch.vtensor<[10,5],f32>
  %19 = torch.aten.detach %18 :  !torch.vtensor<[10,5],f32> ->  !torch.vtensor<[10,5],f32>
  %20 = torch.aten.add.Tensor %0, %19, %arg2 :  !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.float ->  !torch.vtensor<[10,5],f32>
  %int0_16 = torch.constant.int 0
  %21 = torch.prim.ListConstruct %int0_16 : (!torch.int) -> !torch.list<int>
  %true = torch.constant.bool true
  %none_17 = torch.constant.none
  %22 = torch.aten.sum.dim_IntList %11, %21, %true, %none_17 :  !torch.vtensor<[1,10],f32>, !torch.list<int>, !torch.bool, !torch.none ->  !torch.vtensor<[1,10],f32>
  %int10 = torch.constant.int 10
  %23 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list<int>
  %int10_18 = torch.constant.int 10
  %24 = torch.prim.ListConstruct %int10_18 : (!torch.int) -> !torch.list<int>
  %25 = torch.aten.reshape %22, %24 :  !torch.vtensor<[1,10],f32>, !torch.list<int> ->  !torch.vtensor<[10],f32>
  %26 = torch.aten.detach %25 :  !torch.vtensor<[10],f32> ->  !torch.vtensor<[10],f32>
  %27 = torch.aten.add.Tensor %4, %26, %arg9 :  !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float ->  !torch.vtensor<[10],f32>
  return %arg0, %arg1, %20, %27, %6, %output, %26, %19 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<*,f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>
}

Note: This PR is based on #725, but GitHub does not allow setting the base of a PR to a fork while keeping the PR in the upstream repo, so the two share several commits. Only the last commit is specifically part of this PR.

Marked as draft while waiting for a dependency PR to land: #725

Resolves: #727

cc: @antoniojkim @ke1337

@silvasean
Contributor

Thanks! Looks good.

Can you also see if !torch.tensor can be replaced with !torch.vtensor here? I think the Lazy IR should have value semantics; see, e.g., the code in lazy_tensor_core/lazy_tensor_core/csrc/view.h

@henrytwo
Member Author

henrytwo commented Apr 8, 2022

Hm, I took a look at the code, and it seems the only place torchMlirTorchValueTensorTypeGet is called is here, when creating type_bound attributes. The current location for type generation seems to always produce the non-value-semantic tensor type.

Given that this is how things are set up, what's the expected way to generate value-semantic tensors?

I also looked through some of the Lazy IR code (including a file similar to the one you listed) and didn't see anything explicitly about value semantics. If I understand correctly, though, operations with a trailing underscore are performed in place and therefore shouldn't have value semantics?

@silvasean
Contributor

silvasean commented Apr 8, 2022

Hm, I took a look at the code, and it seems the only place torchMlirTorchValueTensorTypeGet is called is here, when creating type_bound attributes. The current location for type generation seems to always produce the non-value-semantic tensor type.

Given that this is how things are set up, what's the expected way to generate value-semantic tensors?

I think we would basically need to thread a bool through the relevant code.

We could create an import_options.h with a struct like:

struct ImportOptions {
  bool assumeTensorHaveValueSemantics = false;
};

That would keep things a bit more organized than passing around a loose bool.
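As a rough sketch (the helper name getMlirTensorType and its plumbing are illustrative, not the actual PR diff; the two type constructors are the existing torch-mlir C API functions), threading the flag through type generation might look like:

#include <vector>
#include "torch-mlir-c/TorchTypes.h"
#include "import_options.h"

MlirType getMlirTensorType(MlirContext context,
                           const std::vector<int64_t> &sizes,
                           MlirType dtype,
                           const ImportOptions &options) {
  if (options.assumeTensorHaveValueSemantics) {
    // Produces !torch.vtensor<[sizes],dtype>
    return torchMlirTorchValueTensorTypeGet(context, sizes.size(),
                                            sizes.data(), dtype);
  }
  // Produces !torch.tensor<[sizes],dtype>
  return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
                                             sizes.data(), dtype);
}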

I also looked through some of the Lazy IR code (including a file similar to the one you listed) and didn't see anything explicitly about value semantics. If I understand correctly, though, operations with a trailing underscore are performed in place and therefore shouldn't have value semantics?

I think this is a question for the LTC devs about the semantics of the lazy IR. They might prefer the term "purely functional" over "value semantics", but it means the same thing in this context. There are a few other cases besides the trailing-underscore variants, such as batch_norm updating running_mean/var in place.

@henrytwo henrytwo force-pushed the henrytu/ltc_backend_shape_information branch from 8eee430 to b520c09 on April 8, 2022 at 19:08
@henrytwo
Member Author

henrytwo commented Apr 8, 2022

assumeTensorHaveValueSemantics

Does this mean you want us to assume tensors do have (i.e. default to) value semantics when generating MLIR, unless we run into a case where we know they shouldn't?

@henrytwo
Member Author

henrytwo commented Apr 8, 2022

There are a few other cases besides the trailing underscore variants, such as batch_norm updating the running_mean/var in place.

Is there an authoritative list for looking up what should and should not have value semantics? I can see in the tablegen that certain ops are listed with HasValueSemantics. Would that be a sufficient check?

edit: Hm, I see that batch_norm is listed with HasValueSemantics even though you mentioned it has some in-place behavior, so maybe not.

@silvasean
Contributor

In general, the alias annotations (which determine HasValueSemantics) are the source of truth. There are a few exceptions: pytorch/pytorch#73050 (comment)

@henrytwo
Member Author

@silvasean Just to clarify before I proceed with implementing these changes: do you want all tensors to be of type !torch.vtensor when generating MLIR via the LTC frontend, or should this decision be based on the alias annotations plus the list of exceptions?

@silvasean
Contributor

@silvasean Just to clarify before I proceed with implementing these changes: do you want all tensors to be of type !torch.vtensor when generating MLIR via the LTC frontend, or should this decision be based on the alias annotations plus the list of exceptions?

I would phrase my request as "where it is trivial to know by construction that the tensor has value semantics, use !torch.vtensor". Practically speaking, I think that means: if the Lazy IR guarantees that all tensors have value semantics, use !torch.vtensor; otherwise, don't worry about it, and MaximizeValueSemantics will do the best it can later.

@henrytwo
Member Author

henrytwo commented Apr 11, 2022

@silvasean Please have a look at the latest commit. I took a look at the discussion you had with @antoniojkim on Discord, and it looks like LTC makes all in-place operations functional, so I've added a flag to use !torch.vtensor.

I previously misunderstood what you meant in your original comment, but it makes sense now :)

@henrytwo henrytwo force-pushed the henrytu/ltc_backend_shape_information branch from ede342b to cec0ecf on April 11, 2022 at 21:57
@henrytwo henrytwo requested a review from silvasean April 11, 2022 21:59
@henrytwo henrytwo force-pushed the henrytu/ltc_backend_shape_information branch from cec0ecf to aea7899 on April 14, 2022 at 16:57
@henrytwo henrytwo marked this pull request as ready for review April 14, 2022 17:01
// calling-convention-impacting decisions, this flag should be interpreted as
// a requirement to use a value-semantic tensor type (!torch.vtensor) in
// signatures.
bool assumeTensorHaveValueSemantics = false;
Contributor

nit: assumeTensorsHaveValueSemantics

Any in-place ops should be made functional by LTC, so it should be safe to use value-semantic tensors for everything.
@henrytwo henrytwo force-pushed the henrytu/ltc_backend_shape_information branch from aea7899 to 4d2aa26 on April 14, 2022 at 18:19
@henrytwo henrytwo merged commit eb73cd0 into llvm:torch_mlir_ltc_backend on April 14, 2022
@henrytwo henrytwo deleted the henrytu/ltc_backend_shape_information branch on April 14, 2022 at 18:36
antoniojkim pushed a commit that referenced this pull request May 26, 2022
* Generate MLIR with shape information via LTC frontend

Previously, the intermediate JIT graph was generated without shape information, which caused generic !torch.tensor types to be scattered throughout the final MLIR. This PR improves the lowering by annotating each jit::Value of tensor type with the corresponding shape from its originating lazy::Node. Additionally, the MLIR parameter and return values are now generated with their respective shapes.

* Use `!torch.vtensor` for MLIR from LTC

Any in-place ops should be made functional by LTC, so it should be safe to use value-semantic tensors for everything.
qedawkins pushed a commit to nod-ai/torch-mlir that referenced this pull request Oct 3, 2022
* change shape inference pass

Signed-off-by: Tong Chen <[email protected]>

* fix function name

Signed-off-by: Tong Chen <[email protected]>