Add support to lower backend-legal operations as TOSA CustomOps #1563

Status: Open · wants to merge 1 commit into base: main
55 changes: 55 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
@@ -105,6 +105,61 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}

def ConvertTorchBackendLegalToTosaCustom
: Pass<"convert-torch-backend-legal-to-tosa-custom", "func::FuncOp"> {
let summary = "Convert Torch backend-legal operations to TOSA CustomOps.";
let description = [{
This pass extends the selective decomposition patterns to the TOSA
backend, converting operations that the user marks as custom ops for the
TOSA backend into CustomOps in the TOSA dialect.
These backend-custom operations will not be decomposed or lowered through
the existing Torch to TOSA conversion patterns, even if such patterns
exist in the subsequent Torch to TOSA conversion pass; instead a TOSA
CustomOp will be created with the specifications below:
1. The 'identifier' attribute in the CustomOp will be the operation
name of the corresponding ATen operation.
2. The 'config' attribute in the CustomOp will be the specification
of the compilation pipeline, i.e. "torch_mlir".
3. All inputs to the ATen operation will be converted into legal TOSA
operations. There will be no distinction between inputs and constant
attributes; all will be treated as inputs to the CustomOp. It is up
to the consuming backend to deduce them based on the ATen operation
semantics.
4. Since the TOSA conversion patterns of Torch operations are responsible
for consuming and converting the constant attributes (such as 'axis' for a
reduction operation), the TOSA CustomOp conversion will also match
and rewrite these attributes as TOSA ConstOps. Specifically:
- !torch.bool -> tosa.ConstOp (of tensor type i64)
- !torch.int -> tosa.ConstOp (of tensor type i64)
- !torch.float -> tosa.ConstOp (of tensor type f32)
- !torch.list<int> -> tosa.ConstOp (of tensor type i64)
- !torch.none -> tosa.CustomOp (of tensor type i1); the 'identifier'
attribute of this CustomOp is 'torch.constant.none'
[Review comment (Collaborator)]
If we know that these are constant, this is what the new implementation_attrs
attribute to tosa.custom is meant to hold. It's a string, which means you
would need to encode the attributes as a big string, but if we find that is
unworkable, we could make it more flexible to contain an ArrayAttr. It's a
bit less convenient on my side, as I'd have to come up with an encoding for a
serialization pass we have, but it's doable if it makes the op more useful
overall.

[Reply (Author)]
I see. I have a question on how to deduce whether an input is a constant
attribute vs. some input that happens to be constant. Does it suffice for
them to have the types listed here to consider them an attribute that needs
to be listed in implementation_attrs?
Also, do you have a schema in mind about the encoding?

- !torch.str -> tosa.CustomOp (of tensor type i8); the 'identifier'
attribute of this CustomOp is 'torch.constant.str'
All other Torch ATen operations will be lowered to TOSA by the Torch
to TOSA conversion pass after this one.
TODO: Extend the contract for other Torch constant operations such as:
- torch.constant.device
- torch.constant.number
5. The input operands of the backend-legal Torch operation and the TOSA
CustomOp will have the same order.
TODO: Update this pass to populate the 'config' attribute of the TOSA
CustomOp to establish a proper operand mapping scheme between
the backend-legal Torch operation and the TOSA CustomOp.
(See the commit at https://github.com/llvm/llvm-project/commit/d94ee70f4f01e4d9eec49e02eff57a5655618401)
}];
let options = [ListOption<
"customOps", "custom-ops", "std::string",
"List of operations considered backend-legal that should be converted to"
" CustomOps in TOSA dialect.",
"llvm::cl::ZeroOrMore">];
let constructor =
"mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(/*customOps=*/{})";
}

def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
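For reference, a standalone run of the new pass might look like the following; this is a sketch assuming the torch-mlir-opt tool and the standard MLIR syntax for list-valued pass options, and torch.aten.softmax.int is only an example op name:

torch-mlir-opt --convert-torch-backend-legal-to-tosa-custom="custom-ops=torch.aten.softmax.int" input.mlir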
5 changes: 4 additions & 1 deletion include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
@@ -16,8 +16,11 @@

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchBackendLegalToTosaCustomPass(ArrayRef<std::string> customOps);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
11 changes: 10 additions & 1 deletion include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
@@ -28,7 +28,16 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
struct TosaBackendPipelineOptions
: public PassPipelineOptions<TosaBackendPipelineOptions> {
ListOption<std::string> customOps{
*this, "custom-ops",
llvm::cl::desc("List of operations considered backend-legal that "
"should be converted to CustomOps in the TOSA dialect.")};
};

void createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const TosaBackendPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
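For illustration, driving the new pipeline options from C++ might look roughly like the sketch below. The surrounding setup (buildTosaPipeline, the PassManager construction, and the example op name) is assumed, not code from this PR:

#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

void buildTosaPipeline(mlir::MLIRContext &context) {
  mlir::PassManager pm(&context);
  mlir::torch::TorchConversion::TosaBackendPipelineOptions options;
  // Ops listed here stay backend-legal and are lowered to tosa.custom;
  // "torch.aten.softmax.int" is just a placeholder example.
  options.customOps = {"torch.aten.softmax.int"};
  mlir::torch::TorchConversion::createTorchBackendToTosaBackendPipeline(
      pm, options);
  // pm.run(module) would then apply the lowering to a loaded module.
}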
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchBackendLegalToTosaCustom.cpp
TorchToTosa.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp
220 changes: 220 additions & 0 deletions lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp
@@ -0,0 +1,220 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

class ConvertBackendLegalAtenOpToCustomOp : public ConversionPattern {
public:
SetVector<StringRef> customOps;

ConvertBackendLegalAtenOpToCustomOp(TypeConverter &typeConverter,
MLIRContext *context,
ArrayRef<std::string> customOps)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {
this->customOps = SetVector<StringRef>(customOps.begin(), customOps.end());
}

Value convertOperandToTensor(Value inputOperand, PatternRewriter &rewriter,
Operation *backendLegalOp) const {
// Get the Torch Op to find the constant attributes attached to the
// backendLegalOp
Value torchInputOperand = inputOperand;
if (auto unrealizedCastOp = dyn_cast_or_null<UnrealizedConversionCastOp>(
inputOperand.getDefiningOp())) {
torchInputOperand = unrealizedCastOp.getInputs()[0];
}
[Review comment (Collaborator), on lines +47 to +53]
Is there a reason this is needed instead of using op->getOperand(...) to get
inputOperand?

[Follow-up (Collaborator, ramiro050, Nov 22, 2022)]
Ah I see why, you need both the torch dialect operand and also the converted
operand for the default case. However, UnrealizedConversionCastOp is an
implementation detail of how partial conversions are performed, so I don't
think it's a good idea to depend on it. I would prefer passing the operand
index as an argument, then doing op->getOperand(operandIndex) to get the
torch dialect operand. This makes it very explicit that torchInputOperand is
indeed in the torch dialect, which is not the case in the initialization of
torchInputOperand in the current implementation.
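A minimal sketch of the operand-index alternative suggested above (hypothetical, not code from this PR):

// Hypothetical variant: the torch-dialect operand is recovered by index from
// the original op, so the pattern never inspects UnrealizedConversionCastOp.
Value convertOperandToTensor(unsigned operandIndex, PatternRewriter &rewriter,
                             Operation *backendLegalOp,
                             ArrayRef<Value> convertedOperands) const {
  // Guaranteed to be a torch-dialect value: it comes from the original op.
  Value torchInputOperand = backendLegalOp->getOperand(operandIndex);
  // The converted operand serves the default (already-a-tensor) case.
  Value inputOperand = convertedOperands[operandIndex];
  // ... constant matching as in the code below, else fall back ...
  return inputOperand;
}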

// Handle the special case where input operand is an argument to the module
// function
if (!torchInputOperand.getDefiningOp())
return inputOperand;
[Review comment (Collaborator), on lines +56 to +57]
This needs to check that the inputOperand is a tensor.

return TypeSwitch<Operation *, Value>(torchInputOperand.getDefiningOp())
[Review comment (Collaborator)]
nit: I feel that using a TypeSwitch in this particular case makes things more
verbose and harder to read than just using if statements. For example, for
all the torch.constants, the type switch first dynamically casts to see which
case to apply, then inside the cases the same dynamic cast is used inside the
matchPattern to get the constant value. We can simply use the matchPattern to
see if the input is a torch.constant. Moreover, all the return nullptr lines
can be factored out if using if statements.
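As an illustration, the flat if/matchPattern structure suggested above might look roughly like this (a sketch of the reviewer's idea, not code from this PR):

// Hypothetical restructuring: each matchPattern call already identifies the
// constant kind, so no TypeSwitch dispatch is needed, and the per-case
// 'return nullptr' lines collapse into the shared fall-through at the end.
bool boolConstAttr;
if (matchPattern(torchInputOperand, m_TorchConstantBool(&boolConstAttr)))
  return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
                                       boolConstAttr, {})
      .value();
int64_t intConstAttr;
if (matchPattern(torchInputOperand, m_TorchConstantInt(&intConstAttr)))
  return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
                                       intConstAttr, {})
      .value();
double floatConstAttr;
if (matchPattern(torchInputOperand, m_TorchConstantFloat(&floatConstAttr)))
  return tosa::getConstTensor<float>(rewriter, backendLegalOp,
                                     floatConstAttr, {})
      .value();
// ... none, str, and int-list cases follow the same shape; non-constant
// operands fall through to the tensor default ...
return inputOperand;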

.Case<Torch::ConstantBoolOp>([&](Operation *boolOperand) -> Value {
bool boolConstAttr;
if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) {
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
boolConstAttr, {})
.value();
}
return nullptr;
})
// TODO Add support for converting "torch.constant.device"
.Case<Torch::ConstantDeviceOp>(
[&](Operation *strOperand) -> Value { return nullptr; })
.Case<Torch::ConstantIntOp>([&](Operation *intOperand) -> Value {
int64_t intConstAttr;
if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) {
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
intConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::ConstantFloatOp>([&](Operation *floatOperand) -> Value {
double floatConstAttr;
if (matchPattern(floatOperand,
m_TorchConstantFloat(&floatConstAttr))) {
return tosa::getConstTensor<float>(rewriter, backendLegalOp,
floatConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::ConstantNoneOp>([&](Operation *noneOperand) -> Value {
auto noneCustomOp = rewriter.create<tosa::CustomOp>(
backendLegalOp->getLoc(),
RankedTensorType::get({}, rewriter.getIntegerType(1)),
rewriter.getStringAttr("constant.none"),
rewriter.getStringAttr("torch_mlir"), rewriter.getStringAttr(""),
ValueRange{});
return noneCustomOp.getResult(0);
})
// TODO Add support for converting "torch.constant.number"
.Case<Torch::ConstantNumberOp>(
[&](Operation *strOperand) -> Value { return nullptr; })
.Case<Torch::ConstantStrOp>([&](Operation *strOperand) -> Value {
std::string strConstAttr;
if (matchPattern(strOperand, m_TorchConstantStr(strConstAttr))) {
auto strCustomOp = rewriter.create<tosa::CustomOp>(
backendLegalOp->getLoc(),
RankedTensorType::get({}, rewriter.getIntegerType(8)),
rewriter.getStringAttr("constant.str"),
rewriter.getStringAttr("torch_mlir"),
rewriter.getStringAttr(""), ValueRange{});
return strCustomOp.getResult(0);
}
return nullptr;
})
.Case<Torch::PrimListConstructOp>(
[&](Operation *intListConstructOperand) -> Value {
SmallVector<int64_t> intConstListAttr;
if (matchPattern(intListConstructOperand,
m_TorchListOfConstantInts(intConstListAttr))) {
return tosa::getConstTensor<int64_t>(
rewriter, backendLegalOp, intConstListAttr,
{static_cast<int64_t>(intConstListAttr.size())})
.value();
}
return nullptr;
})
.Default([&](Operation *defaultOperand) { return inputOperand; });
}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

if (customOps.contains(op->getName().getStringRef())) {
SmallVector<Value> customOpInputOperands;

for (auto operand : operands) {
Value customOpInputOperand =
convertOperandToTensor(operand, rewriter, op);
if (!customOpInputOperand) {
return rewriter.notifyMatchFailure(
op,
"failed to match the constant operand of the backend-legal Op");
}
customOpInputOperands.push_back(customOpInputOperand);
}
SmallVector<Type> customOpResultTypes;
auto convertTypesResult = getTypeConverter()->convertTypes(
op->getResultTypes(), customOpResultTypes);
if (convertTypesResult.failed())
return rewriter.notifyMatchFailure(
op, "failed to convert TOSA CustomOp result types; Only tensor "
"types are supported for the resutls.");
rewriter.replaceOpWithNewOp<tosa::CustomOp>(
op, TypeRange{customOpResultTypes},
llvm::StringRef(op->getName().stripDialect()), // identifier
llvm::StringRef("torch_mlir"), // config
llvm::StringRef(""), // implementation_attrs
ValueRange{customOpInputOperands});
return success();
}
return failure();
}
};

} // namespace

// -----------------------------------------------------------------------------
// TorchBackendLegalToTosaCustom Pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchBackendLegalToTosaCustom
: public ConvertTorchBackendLegalToTosaCustomBase<
ConvertTorchBackendLegalToTosaCustom> {
public:
ConvertTorchBackendLegalToTosaCustom() = default;
ConvertTorchBackendLegalToTosaCustom(ArrayRef<std::string> customOps) {
this->customOps = customOps;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);

patterns.add<ConvertBackendLegalAtenOpToCustomOp>(typeConverter, context,
customOps);

for (std::string opName : customOps) {
target.addIllegalOp(OperationName(opName, context));
}

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(
ArrayRef<std::string> customOps) {
return std::make_unique<ConvertTorchBackendLegalToTosaCustom>(customOps);
}
6 changes: 4 additions & 2 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
@@ -96,7 +96,9 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
}

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const TosaBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createConvertTorchBackendLegalToTosaCustomPass(options.customOps));
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
1 change: 1 addition & 0 deletions python/CMakeLists.txt
@@ -51,6 +51,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
SOURCES
__init__.py
compiler_utils.py
backends/tosa/__init__.py
)

declare_mlir_python_sources(TorchMLIRPythonSources.Dialects