-
Notifications
You must be signed in to change notification settings - Fork 514
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR][TORCH] Add -convert-torch-to-tosa-custom pass for undecomposed…
… torch.aten.softmax.int op
- Loading branch information
AmosLewis
authored and
AmosLewis
committed
Nov 1, 2022
1 parent
a897010
commit 62233fd
Showing
7 changed files
with
207 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/Dialect/Traits.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" | ||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" | ||
#include <unordered_map> | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
namespace { | ||
|
||
// This defines a template to construct ops whose legalizations are | ||
// specialized. | ||
template <typename AtenOpT> | ||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> { | ||
public: | ||
using OpConversionPattern<AtenOpT>::OpConversionPattern; | ||
using OpAdaptor = typename AtenOpT::Adaptor; | ||
LogicalResult | ||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override; | ||
}; | ||
|
||
template <> | ||
LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite( | ||
AtenSoftmaxIntOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const { | ||
// Math: exp(%x) / sum(exp(%x), %dim) | ||
// Torch format: | ||
// "aten.softmax.int"(%x,%dim): (tensor<2x3xf32>, int) -> tensor<2x3xf32> | ||
// Decompose tosa format: with -torch-decompose-complex-ops flag | ||
// https://gist.github.com/AmosLewis/e668c3bfd2472e9f9f045e012362d831 | ||
// %2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32> | ||
// %3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32> | ||
// %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32> | ||
// %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32> | ||
// No-Decompose TOSA format: without -torch-decompose-complex-ops flag | ||
// "tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32> | ||
|
||
// Check AtenSoftmaxIntOp first input is a tensor type. | ||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>(); | ||
if (!selfType) | ||
return rewriter.notifyMatchFailure( | ||
op, "Only tensor types input are currently supported"); | ||
|
||
// Get the dim int64_t type value from AtenSoftmaxIntOp second input, | ||
// type need to convert from mlir::TypedValue<::mlir::torch::Torch::IntType> | ||
int64_t dim; | ||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) | ||
return rewriter.notifyMatchFailure( | ||
op, "unimplemented: value `dtype` should be a torch.constant.int"); | ||
|
||
// Create output type for tosa::CustomOp input | ||
auto outType = getTypeConverter()->convertType(op.getType()); | ||
// Create name attribute and multi-args for tosa::CustomOp input | ||
StringAttr nameValueAttr= rewriter.getStringAttr("softmax"); | ||
auto dimTensor = tosa::getConstTensor<int64_t>( | ||
rewriter, op.getOperation(), dim, {1}); | ||
SmallVector<Value,2> inputOperands{adaptor.self(), dimTensor.value()}; | ||
|
||
// TODO unportable target hardware implementation of exp(%x) / sum(exp(%x), %dim) | ||
rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, nameValueAttr, | ||
inputOperands); | ||
return success(); | ||
} | ||
|
||
|
||
} // namespace | ||
|
||
// ----------------------------------------------------------------------------- | ||
// TorchToTosaCustom Pass | ||
// ----------------------------------------------------------------------------- | ||
|
||
namespace { | ||
class ConvertTorchToTosaCustom : public ConvertTorchToTosaCustomBase<ConvertTorchToTosaCustom> { | ||
public: | ||
ConvertTorchToTosaCustom() = default; | ||
ConvertTorchToTosaCustom(ArrayRef<std::string> customOps) { | ||
this->customOps = customOps; | ||
} | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<tosa::TosaDialect>(); | ||
registry.insert<tensor::TensorDialect>(); | ||
registry.insert<arith::ArithDialect>(); | ||
TorchConversion::getBackendTypeConversionDependentDialects(registry); | ||
} | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect, | ||
arith::ArithDialect>(); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
TorchConversion::setupBackendTypeConversion(target, typeConverter); | ||
|
||
RewritePatternSet patterns(context); | ||
|
||
std::unordered_map<std::string, bool> customOpsMap; | ||
for(auto key: customOps){ | ||
customOpsMap[key] = true; | ||
} | ||
|
||
#define INSERT_ATENOP_PATTERN(AtenOp) \ | ||
target.addIllegalOp<AtenOp>(); \ | ||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context); | ||
if(customOpsMap["torch.aten.softmax.int"]){ | ||
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp); | ||
} | ||
#undef INSERT_ATENOP_PATTERN | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
mlir::torch::createConvertTorchToTosaCustomPass(ArrayRef<std::string> customOps) { | ||
return std::make_unique<ConvertTorchToTosaCustom>(customOps); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int" -split-input-file -verify-diagnostics | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim( | ||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { | ||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> | ||
// CHECK: %[[VAL_2:.*]] = torch.constant.none | ||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 | ||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> | ||
// CHECK: %[[VAL_5:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]]) {identifier = "softmax"} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32> | ||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> | ||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,3],f32> | ||
// CHECK: } | ||
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { | ||
%none = torch.constant.none | ||
%dim = torch.constant.int 1 | ||
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32> | ||
return %ret : !torch.vtensor<[2,3],f32> | ||
} |