Skip to content

Commit

Permalink
[MLIR][TORCH] Add -convert-torch-to-tosa-custom pass for undecomposed…
Browse files Browse the repository at this point in the history
… torch.aten.softmax.int op
  • Loading branch information
AmosLewis authored and AmosLewis committed Nov 1, 2022
1 parent a897010 commit 62233fd
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 3 deletions.
26 changes: 26 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,32 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}

def ConvertTorchToTosaCustom : Pass<"convert-torch-to-tosa-custom", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA custom ops";
let description = [{
The purpose to use tosa::custom is handle complex ops when we donnot
want to decompose them into simple ops. Take softmax for example:
"aten.softmax.int"(%x,%dim): (tensor<2x3xf32>, int) -> tensor<2x3xf32>
Decompose : with -torch-decompose-complex-ops flag
%2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
%4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
%5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
No-Decompose: with convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int"
"tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>
}];
let options = [
ListOption<"customOps", "custom-ops", "std::string",
"List of operation names that should be converted to tosa::custom",
"llvm::cl::ZeroOrMore">
];

let constructor = [{
mlir::torch::createConvertTorchToTosaCustomPass(
/*customOps=*/{})
}];
}

def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Expand Down
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaCustomPass(
ArrayRef<std::string> customOps
);
}
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
struct TosaBackendPipelineOptions
: public PassPipelineOptions<TosaBackendPipelineOptions> {
ListOption<std::string> customOps{
*this, "custom-ops",
llvm::cl::desc("List of ops to be converted to tosa::CustomOp."),
llvm::cl::ZeroOrMore};
};
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm, const TosaBackendPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchToTosa.cpp
TorchToTosaCustom.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp

Expand Down
144 changes: 144 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosaCustom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <unordered_map>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

template <>
LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite(
AtenSoftmaxIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Math: exp(%x) / sum(exp(%x), %dim)
// Torch format:
// "aten.softmax.int"(%x,%dim): (tensor<2x3xf32>, int) -> tensor<2x3xf32>
// Decompose tosa format: with -torch-decompose-complex-ops flag
// https://gist.github.com/AmosLewis/e668c3bfd2472e9f9f045e012362d831
// %2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// %3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
// %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
// %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
// No-Decompose TOSA format: without -torch-decompose-complex-ops flag
// "tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>

// Check AtenSoftmaxIntOp first input is a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

// Get the dim int64_t type value from AtenSoftmaxIntOp second input,
// type need to convert from mlir::TypedValue<::mlir::torch::Torch::IntType>
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `dtype` should be a torch.constant.int");

// Create output type for tosa::CustomOp input
auto outType = getTypeConverter()->convertType(op.getType());
// Create name attribute and multi-args for tosa::CustomOp input
StringAttr nameValueAttr= rewriter.getStringAttr("softmax");
auto dimTensor = tosa::getConstTensor<int64_t>(
rewriter, op.getOperation(), dim, {1});
SmallVector<Value,2> inputOperands{adaptor.self(), dimTensor.value()};

// TODO unportable target hardware implementation of exp(%x) / sum(exp(%x), %dim)
rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, nameValueAttr,
inputOperands);
return success();
}


} // namespace

// -----------------------------------------------------------------------------
// TorchToTosaCustom Pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchToTosaCustom : public ConvertTorchToTosaCustomBase<ConvertTorchToTosaCustom> {
public:
ConvertTorchToTosaCustom() = default;
ConvertTorchToTosaCustom(ArrayRef<std::string> customOps) {
this->customOps = customOps;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}

void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);

std::unordered_map<std::string, bool> customOpsMap;
for(auto key: customOps){
customOpsMap[key] = true;
}

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
if(customOpsMap["torch.aten.softmax.int"]){
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
}
#undef INSERT_ATENOP_PATTERN

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaCustomPass(ArrayRef<std::string> customOps) {
return std::make_unique<ConvertTorchToTosaCustom>(customOps);
}
8 changes: 6 additions & 2 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
Expand Down Expand Up @@ -96,7 +96,11 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
}

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
OpPassManager &pm, const TosaBackendPipelineOptions &options) {
if(!options.customOps.empty()){
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaCustomPass(
options.customOps));
}
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
Expand Down
18 changes: 18 additions & 0 deletions test/Conversion/TorchToTosa/custom.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int" -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]]) {identifier = "softmax"} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32>
return %ret : !torch.vtensor<[2,3],f32>
}

0 comments on commit 62233fd

Please sign in to comment.