-
Notifications
You must be signed in to change notification settings - Fork 516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support to lower backend-legal operations as TOSA CustomOps #1563
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/Dialect/Traits.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" | ||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" | ||
#include "llvm/ADT/SetVector.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
namespace { | ||
|
||
class ConvertBackendLegalAtenOpToCustomOp : public ConversionPattern { | ||
public: | ||
SetVector<StringRef> customOps; | ||
|
||
ConvertBackendLegalAtenOpToCustomOp(TypeConverter &typeConverter, | ||
MLIRContext *context, | ||
ArrayRef<std::string> customOps) | ||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) { | ||
this->customOps = SetVector<StringRef>(customOps.begin(), customOps.end()); | ||
} | ||
|
||
Value convertOperandToTensor(Value inputOperand, PatternRewriter &rewriter, | ||
Operation *backendLegalOp) const { | ||
// Get the Torch Op to find the constant attributes attached to the | ||
// backendLegalOp | ||
Value torchInputOperand = inputOperand; | ||
if (auto unrealizedCastOp = dyn_cast_or_null<UnrealizedConversionCastOp>( | ||
inputOperand.getDefiningOp())) { | ||
torchInputOperand = unrealizedCastOp.getInputs()[0]; | ||
} | ||
Comment on lines
+47
to
+53
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason this is needed instead of using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see why, you need both the torch dialect operand and also the converted operand for the default case. However, |
||
// Handle the special case where input operand is an argument to the module | ||
// function | ||
Svoch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (!torchInputOperand.getDefiningOp()) | ||
return inputOperand; | ||
Comment on lines
+56
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to check that the |
||
|
||
return TypeSwitch<Operation *, Value>(torchInputOperand.getDefiningOp()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I feel that using a |
||
.Case<Torch::ConstantBoolOp>([&](Operation *boolOperand) -> Value { | ||
bool boolConstAttr; | ||
if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) { | ||
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp, | ||
boolConstAttr, {}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
// TODO Add support for converting "torch.constant.device" | ||
.Case<Torch::ConstantDeviceOp>( | ||
[&](Operation *strOperand) -> Value { return nullptr; }) | ||
.Case<Torch::ConstantIntOp>([&](Operation *intOperand) -> Value { | ||
int64_t intConstAttr; | ||
if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) { | ||
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp, | ||
intConstAttr, {}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::ConstantFloatOp>([&](Operation *floatOperand) -> Value { | ||
double floatConstAttr; | ||
if (matchPattern(floatOperand, | ||
m_TorchConstantFloat(&floatConstAttr))) { | ||
return tosa::getConstTensor<float>(rewriter, backendLegalOp, | ||
floatConstAttr, {}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::ConstantNoneOp>([&](Operation *noneOperand) -> Value { | ||
auto noneCustomOp = rewriter.create<tosa::CustomOp>( | ||
backendLegalOp->getLoc(), | ||
RankedTensorType::get({}, rewriter.getIntegerType(1)), | ||
rewriter.getStringAttr("constant.none"), | ||
rewriter.getStringAttr("torch_mlir"), rewriter.getStringAttr(""), | ||
ValueRange{}); | ||
return noneCustomOp.getResult(0); | ||
}) | ||
// TODO Add support for converting "torch.constant.number" | ||
.Case<Torch::ConstantNumberOp>( | ||
[&](Operation *strOperand) -> Value { return nullptr; }) | ||
.Case<Torch::ConstantStrOp>([&](Operation *strOperand) -> Value { | ||
std::string strConstAttr; | ||
if (matchPattern(strOperand, m_TorchConstantStr(strConstAttr))) { | ||
auto strCustomOp = rewriter.create<tosa::CustomOp>( | ||
backendLegalOp->getLoc(), | ||
RankedTensorType::get({}, rewriter.getIntegerType(8)), | ||
rewriter.getStringAttr("constant.str"), | ||
rewriter.getStringAttr("torch_mlir"), | ||
rewriter.getStringAttr(""), ValueRange{}); | ||
return strCustomOp.getResult(0); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::PrimListConstructOp>( | ||
[&](Operation *intListConstructOperand) -> Value { | ||
SmallVector<int64_t> intConstListAttr; | ||
if (matchPattern(intListConstructOperand, | ||
m_TorchListOfConstantInts(intConstListAttr))) { | ||
return tosa::getConstTensor<int64_t>( | ||
rewriter, backendLegalOp, intConstListAttr, | ||
{static_cast<int64_t>(intConstListAttr.size())}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Default([&](Operation *defaultOperand) { return inputOperand; }); | ||
} | ||
|
||
LogicalResult | ||
matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
|
||
if (customOps.contains(op->getName().getStringRef())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: return early to avoid indentation https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code |
||
SmallVector<Value> customOpInputOperands; | ||
|
||
for (auto operand : operands) { | ||
Value customOpInputOperand = | ||
convertOperandToTensor(operand, rewriter, op); | ||
if (!customOpInputOperand) { | ||
return rewriter.notifyMatchFailure( | ||
op, | ||
"failed to match the constant operand of the backend-legal Op"); | ||
} | ||
customOpInputOperands.push_back(customOpInputOperand); | ||
} | ||
SmallVector<Type> customOpResultTypes; | ||
auto convertTypesResult = getTypeConverter()->convertTypes( | ||
op->getResultTypes(), customOpResultTypes); | ||
if (convertTypesResult.failed()) | ||
return rewriter.notifyMatchFailure( | ||
op, "failed to convert TOSA CustomOp result types; Only tensor " | ||
"types are supported for the resutls."); | ||
rewriter.replaceOpWithNewOp<tosa::CustomOp>( | ||
op, TypeRange{customOpResultTypes}, | ||
llvm::StringRef(op->getName().stripDialect()), // identifier | ||
llvm::StringRef("torch_mlir"), // config | ||
llvm::StringRef(""), // implementation_attrs | ||
ValueRange{customOpInputOperands}); | ||
return success(); | ||
} | ||
return failure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
// ----------------------------------------------------------------------------- | ||
// TorchBackendLegalToTosaCustom Pass | ||
// ----------------------------------------------------------------------------- | ||
|
||
namespace { | ||
class ConvertTorchBackendLegalToTosaCustom | ||
: public ConvertTorchBackendLegalToTosaCustomBase< | ||
ConvertTorchBackendLegalToTosaCustom> { | ||
public: | ||
ConvertTorchBackendLegalToTosaCustom() = default; | ||
ConvertTorchBackendLegalToTosaCustom(ArrayRef<std::string> customOps) { | ||
this->customOps = customOps; | ||
} | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<tosa::TosaDialect>(); | ||
registry.insert<tensor::TensorDialect>(); | ||
registry.insert<arith::ArithDialect>(); | ||
TorchConversion::getBackendTypeConversionDependentDialects(registry); | ||
} | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect, | ||
arith::ArithDialect>(); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
TorchConversion::setupBackendTypeConversion(target, typeConverter); | ||
|
||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<ConvertBackendLegalAtenOpToCustomOp>(typeConverter, context, | ||
customOps); | ||
|
||
for (std::string opName : customOps) { | ||
target.addIllegalOp(OperationName(opName, context)); | ||
} | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
// Factory for the pass; `customOps` lists the fully qualified names of the
// backend-legal ops to lower to tosa.custom.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(
    ArrayRef<std::string> customOps) {
  auto pass = std::make_unique<ConvertTorchBackendLegalToTosaCustom>(customOps);
  return pass;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we know that these are constant, this is what the new implementation_attrs attribute to tosa.custom is meant to hold. It's a string, which means you would need to encode the attributes as a big string, but if we find that is unworkable, we could make it more flexible to contain an ArrayAttr. It's a bit less convenient on my side, as I'd have to come up with an encoding for a serialization pass we have, but it's doable if it makes the op more useful overall.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. I have a question on how to deduce whether an input is a constant attribute vs. some input that happens to be constant. Does it suffice for them to have the types listed here to consider them an attribute that needs to be listed in
implementation_attrs
? Also, do you have a schema in mind for the encoding?