Skip to content

Commit

Permalink
[MLIR][TORCH] Add -convert-torch-to-tosa-custom pass for selective un…
Browse files Browse the repository at this point in the history
…decomposed ops
  • Loading branch information
AmosLewis authored and AmosLewis committed Nov 4, 2022
1 parent a897010 commit 3a0be8a
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 3 deletions.
32 changes: 32 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,38 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}

def ConvertTorchToTosaCustom : Pass<"convert-torch-to-tosa-custom", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA custom ops";
let description = [{
The purpose to use tosa::custom is handle complex ops when we donnot
want to decompose them into simple ops.
The aten op name will used to construct a StringAttr as the identifier attribute for tosa::CustomOp.
Each input arg from Aten Dialect has to be converted to a tensor of number values as the
operand of tosa::CustomOp op. After convert, use ValueRange/SmallVector to include
all operand as the final input operands for tosa::CustomOp.

Take softmax for example:
"aten.softmax.int"(%x,%dim): (tensor<2x3xf32>, int) -> tensor<2x3xf32>
Decompose : with -torch-decompose-complex-ops flag
%2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
No-Decompose: with convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int"
"tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>
}];
let options = [
ListOption<"customOps", "custom-ops", "std::string",
"List of operation names that should be converted to tosa::custom",
"llvm::cl::ZeroOrMore">
];

let constructor = [{
mlir::torch::createConvertTorchToTosaCustomPass(
/*customOps=*/{})
}];
}

def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Expand Down
3 changes: 3 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTosaCustomPass(ArrayRef<std::string> customOps);
}
} // namespace mlir

Expand Down
10 changes: 9 additions & 1 deletion include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
struct TosaBackendPipelineOptions
: public PassPipelineOptions<TosaBackendPipelineOptions> {
ListOption<std::string> customOps{
*this, "custom-ops",
llvm::cl::desc("List of ops to be converted to tosa::CustomOp."),
llvm::cl::ZeroOrMore};
};
void createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const TosaBackendPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchToTosa.cpp
TorchToTosaCustom.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp

Expand Down
150 changes: 150 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosaCustom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <unordered_map>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

template <typename AtenOpT>
class ConvertSelectiveAtenOpToTosaCustom : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ValueRange adaptor_operands = adaptor.getOperands();
int num_operands = adaptor_operands.size();
std::vector<mlir::Value> inputs_vec;
for (int i = 0; i < num_operands; i++) {
auto operand = *op.getODSOperands(i).begin();
auto adaptor_operand_type = adaptor_operands[i].getType();

if (adaptor_operand_type
.isa<mlir::IntegerType>()) { // Torch::ConstantIntOp
int64_t operand_tosa;
if (!matchPattern(operand, m_TorchConstantInt(&operand_tosa)))
return rewriter.notifyMatchFailure(
op, "unimplemented: operand should be a torch.constant.int");
auto operand_tensor_int = tosa::getConstTensor<int64_t>(
rewriter, op.getOperation(), operand_tosa, {1});
inputs_vec.push_back(operand_tensor_int.value());
} else if (adaptor_operand_type
.isa<mlir::FloatType>()) { // Torch::ConstantFloatOp
double operand_tosa;
if (!matchPattern(operand, m_TorchConstantFloat(&operand_tosa)))
return rewriter.notifyMatchFailure(
op, "unimplemented: operand should be a torch.constant.float");
auto operand_tensor_float = tosa::getConstTensor<int64_t>(
rewriter, op.getOperation(), operand_tosa, {1});
inputs_vec.push_back(operand_tensor_float.value());
} else if (adaptor_operand_type
.isa<mlir::TensorType>()) { // Torch::ValueTensorType
inputs_vec.push_back(*adaptor.getODSOperands(i).begin());
} else {
// TODO Handle more types like !torch.list<...>, !torch.device,
// !torch.string, !torch.none, !torch.generator.
return rewriter.notifyMatchFailure(
op,
"unimplemented: inputs type. The input has to be int/float ");
}
}
// Create output type for tosa::CustomOp input
auto outType = this->getTypeConverter()->convertType(op.getType());

// Create operands for tosa::CustomOp
llvm::ArrayRef<mlir::Value> ref(inputs_vec.data(), inputs_vec.size());
ValueRange custom_inputs(ref);
rewriter.replaceOpWithNewOp<tosa::CustomOp>(
op, outType, op.getOperationName(), custom_inputs);
return success();
}
};

} // namespace

// -----------------------------------------------------------------------------
// TorchToTosaCustom Pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchToTosaCustom
: public ConvertTorchToTosaCustomBase<ConvertTorchToTosaCustom> {
public:
ConvertTorchToTosaCustom() = default;
ConvertTorchToTosaCustom(ArrayRef<std::string> customOps) {
this->customOps = customOps;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);

std::unordered_map<std::string, bool> customOpsMap;
for (auto key : customOps) {
customOpsMap[key] = true;
}

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertSelectiveAtenOpToTosaCustom<AtenOp>>(typeConverter, \
context);
if (customOpsMap["torch.aten.softmax.int"]) {
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
}
if (customOpsMap["torch.aten.rsub.Scalar"]) {
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
}
#undef INSERT_ATENOP_PATTERN

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaCustomPass(
ArrayRef<std::string> customOps) {
return std::make_unique<ConvertTorchToTosaCustom>(customOps);
}
8 changes: 6 additions & 2 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
Expand Down Expand Up @@ -96,7 +96,11 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
}

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
OpPassManager &pm, const TosaBackendPipelineOptions &options) {
if (!options.customOps.empty()) {
pm.addNestedPass<func::FuncOp>(
createConvertTorchToTosaCustomPass(options.customOps));
}
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
Expand Down
39 changes: 39 additions & 0 deletions test/Conversion/TorchToTosa/custom.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int,torch.aten.rsub.Scalar" -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten.softmax.int"} : (tensor<2x3xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2x3xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%dtype = torch.constant.int 1
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %dtype : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
return %ret : !torch.vtensor<[2,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten.rsub.Scalar"} : (tensor<?x?xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00
%alpha = torch.constant.int 1
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}

0 comments on commit 3a0be8a

Please sign in to comment.