[MLIR][TORCH] Add -convert-torch-to-tosa-custom pass for undecomposed selective ops #1514
Conversation
The CI error is due to this patch. The program is crashing; see https://github.com/llvm/torch-mlir/actions/runs/3292098619/jobs/5427059644#step:6:6218
I don't think this is the right approach for what this patch is trying to achieve. As a user of torch-mlir, I would expect that running the conversion pass just works. I think this requires a more general solution where users can specify which torch ops to pass down as custom ops, and have the conversion to tosa.custom handled for them.
The weird thing is that when I run https://github.com/llvm/torch-mlir/actions/runs/3292098619/jobs/5427059644#step:6:6218:~:text=6191-,%2D%2D,%2D%2D,-6194 on my local machine, everything looks good. Here is the output of my local run command:
Yes. The RFC is on the way from my partner; I am just trying to do something before the RFC is done. This is the initial step to demonstrate that a custom op can be used to handle complex ops that we don't want to decompose. There should be a pass/list where we add ops, and only ops from this list should go through the custom-op route; the rest of the ops should be lowered normally.
Let's wait for the RFC to discuss more. Per-op custom patterns are definitely not the right way to do this. Maybe we could have a pass like torch-to-tosa-convert-to-custom{op-list=aten.softmax.int,...}.
This is an interesting PR. Integer-domain softmax could be done using a custom op, but it need not be. There is a working, bit-accurate decomposition of TFLite softmax that is fairly detailed: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc#L1375
Please note that the legalize_common.cc there is the same as TosaLegalizeCommon.cc; the goal here is to eventually move these builders into the llvm-project TOSA codebase. The straightforward solution here is to copy convertSoftMaxOp() into TosaLegalizeCommon.cpp, then just call it. It should do all the work for you. This approach is used by several other TorchToTosa legalizations that invoke calls into TosaLegalizeCommon to implement legalizations using previously solved pattern replacements.
Note: the goal is to keep the TosaLegalizeCommon.* and TosaLegalizeUtils.* files essentially identical to their TF counterparts legalize_common.* and legalize_utils.*, so that a follow-on PR can move them upstream, since multiple projects can all demonstrably depend on the same conversions.
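A minimal sketch of what calling the copied builder could look like from a TorchToTosa pattern; convertSoftMaxOp's exact signature (rewriter, op, result value, logits value, beta) and the adaptor accessor are assumptions based on the TF legalization, not code from this patch:

// Hypothetical: assumes convertSoftMaxOp has been copied into
// TosaLegalizeCommon.cpp with its TF-style signature.
template <>
LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite(
    AtenSoftmaxIntOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  llvm::Optional<Value> result = tosa::convertSoftMaxOp(
      rewriter, op, op.getResult(), adaptor.self(), /*beta=*/1.0);
  if (!result)
    return rewriter.notifyMatchFailure(op, "softmax legalization failed");
  rewriter.replaceOp(op, {result.value()});
  return success();
}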
Please see previous comment.
The RFC is ready here: #1519. @ramiro050 @sjarus
Adding @eric-k256
Force-pushed from ad1779e to 1ea3d37.
Instead of adding a new pass torch-to-tosa-convert-to-custom{op-list=aten.softmax.int,...}, I added an option to the existing pass -convert-torch-to-tosa. Here is my most recent local result:
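For illustration, invoking the option could look like this; the custom-ops flag name comes from the patch, while the binary name and exact op spelling are assumptions:

torch-mlir-opt -convert-torch-to-tosa="custom-ops=torch.aten.softmax.int" input.mlir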
Force-pushed from 8d0c781 to c6c35b6.
Based on Svoch's requirements, the dim is moved into an input tensor instead of an attribute. Here is the new output:
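A minimal sketch of turning a constant dim into a data operand, assuming torch-mlir's m_TorchConstantInt matcher and the tosa::getConstTensor(rewriter, op, values, shape) helper; this follows the Torch_IntType -> RankedTensorType<1xi64> contract described later in the pass documentation:

// Hypothetical: read the constant dim and materialize it as a
// 1xi64 tosa::ConstOp operand instead of an attribute.
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
  return rewriter.notifyMatchFailure(op, "dim must be a constant int");
auto dimTensor = tosa::getConstTensor<int64_t>(rewriter, op, {dim}, {1});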
Just got an E2E test bug caused by the option I added:
Option<bool> custom{*this, "custom-ops",
                    llvm::cl::desc("custom complex operations."),
                    llvm::cl::init(true)};
ListOption<std::string> customOps{
    *this, "custom-ops",
    llvm::cl::desc("List of ops to be converted to the backend."),
    llvm::cl::ZeroOrMore};
I think the build issue is here: there are two options with the same name, "custom-ops".
Makes sense. Thanks, Ramiro. I have fixed this; could you review and approve this patch if everything else makes sense?
Force-pushed from c6c35b6 to c510133.
Please review this patch again. The purpose of this patch is to work with ops that we choose not to decompose.
This feels halfway between being a one-off solution for torch.aten.softmax.int, and providing a flexible way to handle many ops in this manner. With some relatively small changes, I think it can be an example use of the flexible mechanism.
@@ -3854,6 +3902,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
    INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
    INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
    if (!customOps.empty()) {
This feels awkward, and it isn't checking the value of customOps to make sure that softmax is in the list. I think a better option is to always insert the pattern and then, at the top of the softmax matchAndRewrite, search for torch.aten.softmax.int in the string and use that to determine whether to create the custom op.
Some checking that the customOps list is valid would also be good. With this change, someone could put any op in the customOps list and get no warning that the pass will not do what is being asked. At least having a list of ops that support legalization to tosa.custom could be used to warn if an unexpected op was passed.
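A minimal sketch of the suggested gating, assuming the pattern has access to the customOps list (names here are assumptions, not the patch's code):

// Hypothetical: bail out early unless the op was requested on the
// custom-ops list, instead of conditionally inserting the pattern.
LogicalResult
matchAndRewrite(AtenSoftmaxIntOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  if (!llvm::is_contained(customOps, "torch.aten.softmax.int"))
    return rewriter.notifyMatchFailure(
        op, "torch.aten.softmax.int not in custom-ops list");
  // ... emit tosa::CustomOp as before ...
}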
Done
: public PassPipelineOptions<TosaBackendPipelineOptions> {
  // If this option is true, custom complex operations.
  // If this option is false, skip decomposition of complex operations.
  Option<bool> enableCustomOps{*this, "enableCustomOps",
I don't see any other reference to this; is this option needed, or is the custom-ops ListOption sufficient?
Deleted
@AmosLewis, can you address this?
Force-pushed from 01ac54b to 62233fd.
I couldn't find out what interface() means. Could you point me to the code for it?
We had thoughts of making bigger changes to tosa.custom and calling it tosa.interface. That hasn't happened, but I do have a proposed change to tosa.custom in the dialect here: https://reviews.llvm.org/D137133. The other point Suraj was making is that we do have a legalization for softmax int into a series of TOSA ops that could be used to work around this specific problem. The larger problem of conversion from torch to tosa.custom remains, which the pass described by Ramiro aims to solve.
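For context, creating a tosa.custom op from a pattern could look roughly like the sketch below; the builder arguments (an identifier StringAttr plus variadic inputs) reflect my reading of the dialect at the time of this discussion and are an assumption, not code from the patch:

// Hypothetical: tag the tosa::CustomOp with the originating aten op
// name so users can see where the custom op came from.
rewriter.replaceOpWithNewOp<tosa::CustomOp>(
    op, resultTypes, rewriter.getStringAttr("torch.aten.softmax.int"),
    customOpInputs);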
Force-pushed from 68d68bb to 3a0be8a.
Force-pushed from bb7b458 to 1a70859.
@ramiro050 I just added a generic matchAndRewrite for the torch-to-tosa-custom pass. Now there is no need to write a new matchAndRewrite for new ops. It worked for the 2 examples in custom.mlir. We can add more type conversion support to it later. Please review.
This patch is not a special fix for softmax; please review the new update.
namespace {

template <typename AtenOpT>
This shouldn't be a template. It should be a class pattern that works on Operation * inputs. Here's an example:
torch-mlir/lib/Conversion/TorchToLinalg/Uncategorized.cpp, lines 1031 to 1038 in fedf8c0:
class ConvertElementwiseOp : public ConversionPattern {
public:
  ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
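Following that example, a generic version for this pass might look roughly like the sketch below; the customOps plumbing and the class name are assumptions, not the patch's actual code:

// Hypothetical sketch: one pattern matching any op, keyed on the
// custom-ops list rather than a template parameter per aten op.
class ConvertSelectedAtenOpToTosaCustom : public ConversionPattern {
public:
  ConvertSelectedAtenOpToTosaCustom(TypeConverter &typeConverter,
                                    MLIRContext *context,
                                    ArrayRef<std::string> customOps)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context),
        customOps(customOps.begin(), customOps.end()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only ops explicitly requested by the user take the custom-op route.
    if (!customOps.count(op->getName().getStringRef().str()))
      return rewriter.notifyMatchFailure(op, "op not in custom-ops list");
    // ... convert each operand per the contract, then emit tosa::CustomOp ...
    return success();
  }

private:
  std::set<std::string> customOps;
};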
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  ValueRange adaptor_operands = adaptor.getOperands();
In torch-mlir, we use camel-case
int num_operands = adaptor_operands.size();
std::vector<mlir::Value> inputs_vec;
for (int i = 0; i < num_operands; i++) {
  auto operand = *op.getODSOperands(i).begin();
You can do this by calling operand = op.getOperand(i);
Done
auto operand = *op.getODSOperands(i).begin();
auto adaptor_operand_type = adaptor_operands[i].getType();
// type convert for operands
if (adaptor_operand_type
There shouldn't be a need to check the adaptor types. matchPattern will check if the input is of the expected type or not.
if (!matchPattern(operand, m_TorchConstantFloat(&operand_tosa)))
  return rewriter.notifyMatchFailure(
      op, "unimplemented: operand should be a torch.constant.float");
auto operand_tensor_float = tosa::getConstTensor<int64_t>(
The type of this tensor should be double, right?
Done
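A minimal sketch of the corrected constant creation, assuming torch-mlir's tosa::getConstTensor(rewriter, op, values, shape) helper; the pass description's contract below maps Torch_FloatType to a 1xf32 tensor, so a float element type is shown here:

// Hypothetical: read the Torch float and materialize it as a
// one-element float constant tensor instead of an int64_t one.
double value;
if (!matchPattern(operand, m_TorchConstantFloat(&value)))
  return rewriter.notifyMatchFailure(
      op, "unimplemented: operand should be a torch.constant.float");
auto operand_tensor_float = tosa::getConstTensor<float>(
    rewriter, op, {static_cast<float>(value)}, {1});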
let description = [{
  The purpose to use tosa::custom is handle complex ops when we donnot
  want to decompose them into simple ops.
  The aten op name will used to construct a StringAttr as the identifier attribute for tosa::CustomOp.
will used -> will be used
Done
The purpose to use tosa::custom is handle complex ops when we donnot
want to decompose them into simple ops.
The aten op name will used to construct a StringAttr as the identifier attribute for tosa::CustomOp.
So in the output the users know where is the tosa::CustomOp is coming from by the StringAttr id.
where is the -> where the
Done
operand of tosa::CustomOp op. After convert, use ValueRange/SmallVector to include
all operands as the final input operands for tosa::CustomOp.
The contract to follow:
AnyTorchTensorType -> TensorType<?xAnyType> of tosa::ConstOp
nit: I would replace ? with AnySize, since ? means the type has a single dynamic dimension.
Done
operand of tosa::CustomOp op. After convert, use ValueRange/SmallVector to include
all operands as the final input operands for tosa::CustomOp.
I don't think this last sentence is necessary. It's more of an implementation detail.
Done
AnyTorchTensorType -> TensorType<?xAnyType> of tosa::ConstOp
Torch_IntType -> RankedTensorType<1xi64> of tosa::ConstOp
Torch_FloatType -> RankedTensorType<1xf32> of tosa::ConstOp
...
I would also add a TODO explaining the things that are currently unsupported.
Done
Force-pushed from da236fc to 14fd994.
Folks, I believe the approach in my recent PR (in reference to the discussion in the Selective Op decomposition RFC) addresses the outstanding comments in this PR. @AmosLewis, do you mind if we move the discussion over to the other PR?
Sure, let's relocate to your patch.
Force-pushed from 14fd994 to 86625b9.
Use a tosa::custom op to represent the softmax op, which we deliberately choose not to decompose into simple ops. The reasoning is in the RFC here: #1519