Skip to content

Commit

Permalink
Add support to lower backend-legal operations as TOSA CustomOps
Browse files Browse the repository at this point in the history
  • Loading branch information
Svoch committed Nov 8, 2022
1 parent 9a73b9e commit bc3e7a5
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 6 deletions.
53 changes: 53 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,59 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}

def ConvertTorchBackendLegalToTosaCustom
: Pass<"convert-torch-backend-legal-to-tosa-custom", "func::FuncOp"> {
let summary = "Convert Torch backend-legal operations to TOSA CustomOps.";
let description = [{
This pass extends the selective decomposition patterns to TOSA backend,
such that it will convert operations marked to custom-ops by the user
for the TOSA backend to CustomOps in TOSA dialect.
The backend-custom operations will not be decomposed or lowered through
the existing Torch to TOSA conversion patterns, even if such patterns
exist in the subsequent Torch to TOSA conversion pass; instead a TOSA
CustomOp will be created with the specifications below:
1. The 'identifier' attribute in the CustomOp will be the operation
name of the corresponding ATen operation.
2. The 'config' attribute in the CustomOp will be the specification
of the compilation pipeline, i.e. "torch_mlir"
3. All inputs to the ATen operation will be converted into legal TOSA
operations. There will be no distiction between inputs and constant
attributes; all will be treated as inputs to the CustomOp. It is up
to the consuming backend to deduce them based on the ATen operation
semantics.
4. Since TOSA conversion pattern of Torch operations are responsible
to consume and convet the constant attributes (such as 'axis' for a
reduction operation), the TOSA CustomOp conversion will also match
and rewrite these attributes as TOSA ConstOps as well. Specifically:
- torch.constant.bool -> tosa.ConstOp (of tensor type i64)
- torch.constant.int -> tosa.ConstOp (of tensor type i64)
- torch.constant.float -> tosa.ConstOp (of tensor type f32)
- torch.prim.ListConstruct -> tosa.ConstOp (of tensor type i64)
- torch.constant.none -> tosa.CustomOp (of tensor type i1)
The 'identifier' attribute of this
CustomOp is 'torch.constant.none'
All other Torch ATen operations will be lowered to TOSA by the Torch
to TOSA conversion pass after this one.
TODO: Extend the contract for other Torch constant operations such as:
- torch.constant.device
- torch.constant.number
- torch.constant.str
5. The order of the input operands of the backend-legal Torch operation
preserved and the TOSA CustomOp will have the same order.
TODO: Update this pass to populate the 'config' attribute of the TOSA
CustomOp to establish a proper operand mapping scheme between
the backend-legal Torch operation and the TOSA CustomOp.
(See the commit at https://github.com/llvm/llvm-project/commit/d94ee70f4f01e4d9eec49e02eff57a5655618401)
}];
let options = [ListOption<
"customOps", "custom-ops", "std::string",
"List of operations considered backend-legal and should be converted "
"to converted to CustomOps in TOSA dialect.",
"llvm::cl::ZeroOrMore">];
let constructor =
"mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(/*customOps=*/{})";
}

def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
Expand Down
5 changes: 4 additions & 1 deletion include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchBackendLegalToTosaCustomPass(ArrayRef<std::string> customOps);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
11 changes: 10 additions & 1 deletion include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
struct TosaBackendPipelineOptions
: public PassPipelineOptions<TosaBackendPipelineOptions> {
ListOption<std::string> customOps{
*this, "custom-ops",
llvm::cl::desc("List of operations considered backend-legal and "
"should be converted to CustomOps in TOSA dialect.")};
};

void createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const TosaBackendPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchBackendLegalToTosaCustom.cpp
TorchToTosa.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp
Expand Down
201 changes: 201 additions & 0 deletions lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

class ConvertBackendLegalAtenOp : public ConversionPattern {
public:
ArrayRef<std::string> customOps;

ConvertBackendLegalAtenOp(TypeConverter &typeConverter, MLIRContext *context,
ArrayRef<std::string> customOps)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {
this->customOps = customOps;
}

Value matchCustomOpOperandValue(Value inputOperand, PatternRewriter &rewriter,
Operation *backendLegalOp) const {
// Get the Torch Op to find the constant attributes attached to the
// backendLegalOp
Value torchInputOperand = inputOperand;
if (auto unrealizedCastOp = dyn_cast_or_null<UnrealizedConversionCastOp>(
inputOperand.getDefiningOp())) {
torchInputOperand = unrealizedCastOp.getInputs()[0];
}
// Handle the special case where input operand is an argument to the module
// function
if (!torchInputOperand.getDefiningOp())
return inputOperand;

return TypeSwitch<Operation *, Value>(torchInputOperand.getDefiningOp())
.Case<Torch::ConstantBoolOp>([&](Operation *boolOperand) -> Value {
bool boolConstAttr;
if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) {
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
boolConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::ConstantIntOp>([&](Operation *intOperand) -> Value {
int64_t intConstAttr;
if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) {
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp,
intConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::ConstantFloatOp>([&](Operation *floatOperand) -> Value {
double floatConstAttr;
if (matchPattern(floatOperand,
m_TorchConstantFloat(&floatConstAttr))) {
return tosa::getConstTensor<float>(rewriter, backendLegalOp,
floatConstAttr, {})
.value();
}
return nullptr;
})
.Case<Torch::PrimListConstructOp>(
[&](Operation *intListConstructOperand) -> Value {
SmallVector<int64_t> intConstListAttr;
if (matchPattern(intListConstructOperand,
m_TorchConstantIntList(intConstListAttr))) {
return tosa::getConstTensor<int64_t>(
rewriter, backendLegalOp, intConstListAttr,
{static_cast<int64_t>(intConstListAttr.size())})
.value();
}
return nullptr;
})
.Case<Torch::ConstantNoneOp>([&](Operation *noneOperand) -> Value {
auto noneCustomOp = rewriter.create<tosa::CustomOp>(
backendLegalOp->getLoc(),
RankedTensorType::get({}, rewriter.getIntegerType(1)),
rewriter.getStringAttr("torch.none"), ValueRange{});
return noneCustomOp.getResult(0);
})
.Default([&](Operation *defaultOperand) { return inputOperand; });
}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

for (auto customOpName : customOps) {
if (customOpName == op->getName().getStringRef()) {
SmallVector<Value, 4> customOpInputOperands;

for (auto operand : operands) {
Value customOpInputOperand =
matchCustomOpOperandValue(operand, rewriter, op);
if (!customOpInputOperand) {
return rewriter.notifyMatchFailure(
op,
"failed to match the constant operand of the backend-legal Op");
}
customOpInputOperands.push_back(customOpInputOperand);
}
SmallVector<Type> customOpResultTypes;
auto convertTypesResult = getTypeConverter()->convertTypes(
op->getResultTypes(), customOpResultTypes);
if (convertTypesResult.failed())
return failure();
auto tosaCustomOp = rewriter.create<tosa::CustomOp>(
op->getLoc(), TypeRange{customOpResultTypes},
llvm::StringRef(op->getName().stripDialect()),
ValueRange{customOpInputOperands});
// TODO: Use the 'config' StringAttr in the TOSA CustomOp interface
// instead. This requires the LLVM version uplift to include new TOSA
// CustomOp definition
tosaCustomOp.getOperation()->setAttr(
rewriter.getStringAttr("config"),
rewriter.getStringAttr("torch_mlir"));
rewriter.replaceOp(op, tosaCustomOp.getResults());
return success();
}
}
return failure();
}
};

} // namespace

// -----------------------------------------------------------------------------
// TorchBackendLegalToTosaCustom Pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchBackendLegalToTosaCustom
: public ConvertTorchBackendLegalToTosaCustomBase<
ConvertTorchBackendLegalToTosaCustom> {
public:
ConvertTorchBackendLegalToTosaCustom() = default;
ConvertTorchBackendLegalToTosaCustom(ArrayRef<std::string> customOps) {
this->customOps = customOps;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);

patterns.add<ConvertBackendLegalAtenOp>(typeConverter, context, customOps);

for (std::string opName : customOps) {
target.addIllegalOp(OperationName(opName, context));
}

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(
ArrayRef<std::string> customOps) {
return std::make_unique<ConvertTorchBackendLegalToTosaCustom>(customOps);
}
6 changes: 4 additions & 2 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
Expand Down Expand Up @@ -96,7 +96,9 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
}

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
OpPassManager &pm, const TosaBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createConvertTorchBackendLegalToTosaCustomPass(options.customOps));
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
Expand Down
13 changes: 11 additions & 2 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False,
ignore_traced_shapes = False,
backend_legal_ops: Optional[Sequence[str]] = None,
backend_custom_ops: Optional[Sequence[str]] = None,
verbose: bool = False):
"""Convert a PyTorch model to MLIR.
Expand All @@ -165,6 +166,9 @@ def compile(model: torch.nn.Module,
backend_legal_ops: A list of ops that should be considered legal for
the backend. An op that is considered legal will not be decomposed.
This option is only valid with the `"torch"` output type.
backend_custom_ops: A list of ops to be converted to the custom ops in
the backend dialect. Currently, only this option is only used by
the TOSA backend.
verbose: If true, print extra information about the conversion.
Returns:
Expand Down Expand Up @@ -192,7 +196,7 @@ def compile(model: torch.nn.Module,
# tensor to a list of a single tensor to make the API more ergonomic.
if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
example_args = (example_args,)

# If users passed in anything other than tensors or a list of tensors (e.g.
# a dictionary), we can't handle it.
if not isinstance(example_args, Sequence):
Expand Down Expand Up @@ -285,10 +289,15 @@ def compile(model: torch.nn.Module,
if output_type == OutputType.TORCH:
return mb.module

if backend_custom_ops is not None:
backend_custom_ops = list(sorted(set(backend_custom_ops)))
else:
backend_custom_ops = []
backend_option_string = "{custom-ops=" + ",".join(backend_custom_ops) + "}"
if output_type == OutputType.TOSA:
run_pipeline_with_repro_report(
mb.module,
"torch-backend-to-tosa-backend-pipeline",
f"torch-backend-to-tosa-backend-pipeline{backend_option_string}",
"Lowering Torch Backend IR -> TOSA Backend IR")
if verbose:
print("\n====================")
Expand Down
19 changes: 19 additions & 0 deletions test/Conversion/TorchToTosa/selective_decomposition.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: torch-mlir-opt <%s -convert-torch-backend-legal-to-tosa-custom="custom-ops=torch.aten._softmax" -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten._softmax(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,4],f32> -> tensor<4x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten._softmax"} : (tensor<4x4xf32>, tensor<i64>, tensor<i64>) -> tensor<4x4xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x4xf32> -> !torch.vtensor<[4,4],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,4],f32>
// CHECK: }
func.func @torch.aten._softmax(%arg0: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%0 = torch.aten._softmax %arg0, %int-1, %false : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[4,4],f32>
return %0 : !torch.vtensor<[4,4],f32>
}

1 comment on commit bc3e7a5

@AmosLewis
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {config = "torch_mlir", identifier = "aten._softmax"} : (tensor<4x4xf32>, tensor<i64>, tensor<i64>) -> tensor<4x4xf32>

Please sign in to comment.