Skip to content

Commit

Permalink
Add support to lower backend-legal operations as TOSA CustomOps
Browse files Browse the repository at this point in the history
Add a "backends.tosa" module for specifiying backend-custom-ops
 - These operations will be converted to CustomOps
  • Loading branch information
Svoch committed Nov 21, 2022
1 parent 22307a1 commit e3c3852
Show file tree
Hide file tree
Showing 9 changed files with 384 additions and 4 deletions.
55 changes: 55 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,61 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}

def ConvertTorchBackendLegalToTosaCustom
: Pass<"convert-torch-backend-legal-to-tosa-custom", "func::FuncOp"> {
let summary = "Convert Torch backend-legal operations to TOSA CustomOps.";
let description = [{
This pass extends the selective decomposition patterns to TOSA backend,
such that it will convert operations marked to custom-ops by the user
for the TOSA backend to CustomOps in TOSA dialect.
The backend-custom operations will not be decomposed or lowered through
the existing Torch to TOSA conversion patterns, even if such patterns
exist in the subsequent Torch to TOSA conversion pass; instead a TOSA
CustomOp will be created with the specifications below:
1. The 'identifier' attribute in the CustomOp will be the operation
name of the corresponding ATen operation.
2. The 'config' attribute in the CustomOp will be the specification
of the compilation pipeline, i.e. "torch_mlir"
3. All inputs to the ATen operation will be converted into legal TOSA
operations. There will be no distiction between inputs and constant
attributes; all will be treated as inputs to the CustomOp. It is up
to the consuming backend to deduce them based on the ATen operation
semantics.
4. Since TOSA conversion pattern of Torch operations are responsible
to consume and convet the constant attributes (such as 'axis' for a
reduction operation), the TOSA CustomOp conversion will also match
and rewrite these attributes as TOSA ConstOps as well. Specifically:
- !torch.bool -> tosa.ConstOp (of tensor type i64)
- !torch.int -> tosa.ConstOp (of tensor type i64)
- !torch.float -> tosa.ConstOp (of tensor type f32)
- !torch.list<int> -> tosa.ConstOp (of tensor type i64)
- !torch.none -> tosa.CustomOp (of tensor type i1)
The 'identifier' attribute of this
CustomOp is 'torch.constant.none'
- !torch.str -> tosa.CustomOp (of tensor type i1)
The 'identifier' attribute of this
CustomOp is 'torch.constant.str'
All other Torch ATen operations will be lowered to TOSA by the Torch
to TOSA conversion pass after this one.
TODO: Extend the contract for other Torch constant operations such as:
- torch.constant.device
- torch.constant.number
5. The input operands of the backend-legal Torch operation and the TOSA
CustomOp will have the same order.
TODO: Update this pass to populate the 'config' attribute of the TOSA
CustomOp to establish a proper operand mapping scheme between
the backend-legal Torch operation and the TOSA CustomOp.
(See the commit at https://github.com/llvm/llvm-project/commit/d94ee70f4f01e4d9eec49e02eff57a5655618401)
}];
let options = [ListOption<
"customOps", "custom-ops", "std::string",
"List of operations considered backend-legal that should be converted to"
" CustomOps in TOSA dialect.",
"llvm::cl::ZeroOrMore">];
let constructor =
"mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(/*customOps=*/{})";
}

def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
Expand Down
5 changes: 4 additions & 1 deletion include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchBackendLegalToTosaCustomPass(ArrayRef<std::string> customOps);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
11 changes: 10 additions & 1 deletion include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
struct TosaBackendPipelineOptions
: public PassPipelineOptions<TosaBackendPipelineOptions> {
ListOption<std::string> customOps{
*this, "custom-ops",
llvm::cl::desc("List of operations considered backend-legal and "
"should be converted to CustomOps in TOSA dialect.")};
};

void createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const TosaBackendPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchBackendLegalToTosaCustom.cpp
TorchToTosa.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp
Expand Down
220 changes: 220 additions & 0 deletions lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

class ConvertBackendLegalAtenOpToCustomOp : public ConversionPattern {
public:
SetVector<StringRef> customOps;

ConvertBackendLegalAtenOpToCustomOp(TypeConverter &typeConverter,
MLIRContext *context,
ArrayRef<std::string> customOps)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {
this->customOps = SetVector<StringRef>(customOps.begin(), customOps.end());
}

Value convertOperandToTensor(Value inputOperand, PatternRewriter &rewriter,
Operation *backendLegalOp) const {
// Get the Torch Op to find the constant attributes attached to the
// backendLegalOp
Value torchInputOperand = inputOperand;
if (auto unrealizedCastOp = dyn_cast_or_null<UnrealizedConversionCastOp>(
inputOperand.getDefiningOp())) {
torchInputOperand = unrealizedCastOp.getInputs()[0];
}
// Handle the special case where input operand is an argument to the module
// function
if (!torchInputOperand.getDefiningOp())
return inputOperand;

return TypeSwitch<Operation *, Value>(torchInputOperand.getDefiningOp())
.Case<Torch::ConstantBoolOp>([&](Operation *boolOperand) -> Value {
bool boolConstAttr;
if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) {
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
boolConstAttr, {})
.value();
}
return nullptr;
})
// TODO Add support for converting "torch.constant.device"
.Case<Torch::ConstantDeviceOp>(
[&](Operation *strOperand) -> Value { return nullptr; })
.Case<Torch::ConstantIntOp>([&](Operation *intOperand) -> Value {
int64_t intConstAttr;
if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) {
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
intConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::ConstantFloatOp>([&](Operation *floatOperand) -> Value {
double floatConstAttr;
if (matchPattern(floatOperand,
m_TorchConstantFloat(&floatConstAttr))) {
return tosa::getConstTensor<float>(rewriter, backendLegalOp,
floatConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::ConstantNoneOp>([&](Operation *noneOperand) -> Value {
auto noneCustomOp = rewriter.create<tosa::CustomOp>(
backendLegalOp->getLoc(),
RankedTensorType::get({}, rewriter.getIntegerType(1)),
rewriter.getStringAttr("constant.none"),
rewriter.getStringAttr("torch_mlir"), rewriter.getStringAttr(""),
ValueRange{});
return noneCustomOp.getResult(0);
})
// TODO Add support for converting "torch.constant.number"
.Case<Torch::ConstantNumberOp>(
[&](Operation *strOperand) -> Value { return nullptr; })
.Case<Torch::ConstantStrOp>([&](Operation *strOperand) -> Value {
std::string strConstAttr;
if (matchPattern(strOperand, m_TorchConstantStr(strConstAttr))) {
auto strCustomOp = rewriter.create<tosa::CustomOp>(
backendLegalOp->getLoc(),
RankedTensorType::get({}, rewriter.getIntegerType(8)),
rewriter.getStringAttr("constant.str"),
rewriter.getStringAttr("torch_mlir"),
rewriter.getStringAttr(""), ValueRange{});
return strCustomOp.getResult(0);
}
return nullptr;
})
.Case<Torch::PrimListConstructOp>(
[&](Operation *intListConstructOperand) -> Value {
SmallVector<int64_t> intConstListAttr;
if (matchPattern(intListConstructOperand,
m_TorchListOfConstantInts(intConstListAttr))) {
return tosa::getConstTensor<int64_t>(
rewriter, backendLegalOp, intConstListAttr,
{static_cast<int64_t>(intConstListAttr.size())})
.value();
}
return nullptr;
})
.Default([&](Operation *defaultOperand) { return inputOperand; });
}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

if (customOps.contains(op->getName().getStringRef())) {
SmallVector<Value> customOpInputOperands;

for (auto operand : operands) {
Value customOpInputOperand =
convertOperandToTensor(operand, rewriter, op);
if (!customOpInputOperand) {
return rewriter.notifyMatchFailure(
op,
"failed to match the constant operand of the backend-legal Op");
}
customOpInputOperands.push_back(customOpInputOperand);
}
SmallVector<Type> customOpResultTypes;
auto convertTypesResult = getTypeConverter()->convertTypes(
op->getResultTypes(), customOpResultTypes);
if (convertTypesResult.failed())
return rewriter.notifyMatchFailure(
op, "failed to convert TOSA CustomOp result types; Only tensor "
"types are supported for the resutls.");
rewriter.replaceOpWithNewOp<tosa::CustomOp>(
op, TypeRange{customOpResultTypes},
llvm::StringRef(op->getName().stripDialect()), // identifier
llvm::StringRef("torch_mlir"), // config
llvm::StringRef(""), // implementation_attrs
ValueRange{customOpInputOperands});
return success();
}
return failure();
}
};

} // namespace

// -----------------------------------------------------------------------------
// TorchBackendLegalToTosaCustom Pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchBackendLegalToTosaCustom
: public ConvertTorchBackendLegalToTosaCustomBase<
ConvertTorchBackendLegalToTosaCustom> {
public:
ConvertTorchBackendLegalToTosaCustom() = default;
ConvertTorchBackendLegalToTosaCustom(ArrayRef<std::string> customOps) {
this->customOps = customOps;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);

patterns.add<ConvertBackendLegalAtenOpToCustomOp>(typeConverter, context,
customOps);

for (std::string opName : customOps) {
target.addIllegalOp(OperationName(opName, context));
}

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(
ArrayRef<std::string> customOps) {
return std::make_unique<ConvertTorchBackendLegalToTosaCustom>(customOps);
}
6 changes: 4 additions & 2 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
Expand Down Expand Up @@ -96,7 +96,9 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
}

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
OpPassManager &pm, const TosaBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createConvertTorchBackendLegalToTosaCustomPass(options.customOps));
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
Expand Down
1 change: 1 addition & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
SOURCES
__init__.py
compiler_utils.py
backends/tosa/__init__.py
)

declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
Expand Down
Loading

0 comments on commit e3c3852

Please sign in to comment.