-
Notifications
You must be signed in to change notification settings - Fork 514
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR][TORCH] Add -convert-torch-to-tosa-custom pass for selective un…
…decomposed ops
- Loading branch information
AmosLewis
authored and
AmosLewis
committed
Nov 4, 2022
1 parent
a897010
commit 3a0be8a
Showing
7 changed files
with
240 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/Dialect/Traits.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" | ||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" | ||
#include <unordered_map> | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
namespace { | ||
|
||
template <typename AtenOpT> | ||
class ConvertSelectiveAtenOpToTosaCustom : public OpConversionPattern<AtenOpT> { | ||
public: | ||
using OpConversionPattern<AtenOpT>::OpConversionPattern; | ||
using OpAdaptor = typename AtenOpT::Adaptor; | ||
LogicalResult | ||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
ValueRange adaptor_operands = adaptor.getOperands(); | ||
int num_operands = adaptor_operands.size(); | ||
std::vector<mlir::Value> inputs_vec; | ||
for (int i = 0; i < num_operands; i++) { | ||
auto operand = *op.getODSOperands(i).begin(); | ||
auto adaptor_operand_type = adaptor_operands[i].getType(); | ||
|
||
if (adaptor_operand_type | ||
.isa<mlir::IntegerType>()) { // Torch::ConstantIntOp | ||
int64_t operand_tosa; | ||
if (!matchPattern(operand, m_TorchConstantInt(&operand_tosa))) | ||
return rewriter.notifyMatchFailure( | ||
op, "unimplemented: operand should be a torch.constant.int"); | ||
auto operand_tensor_int = tosa::getConstTensor<int64_t>( | ||
rewriter, op.getOperation(), operand_tosa, {1}); | ||
inputs_vec.push_back(operand_tensor_int.value()); | ||
} else if (adaptor_operand_type | ||
.isa<mlir::FloatType>()) { // Torch::ConstantFloatOp | ||
double operand_tosa; | ||
if (!matchPattern(operand, m_TorchConstantFloat(&operand_tosa))) | ||
return rewriter.notifyMatchFailure( | ||
op, "unimplemented: operand should be a torch.constant.float"); | ||
auto operand_tensor_float = tosa::getConstTensor<int64_t>( | ||
rewriter, op.getOperation(), operand_tosa, {1}); | ||
inputs_vec.push_back(operand_tensor_float.value()); | ||
} else if (adaptor_operand_type | ||
.isa<mlir::TensorType>()) { // Torch::ValueTensorType | ||
inputs_vec.push_back(*adaptor.getODSOperands(i).begin()); | ||
} else { | ||
// TODO Handle more types like !torch.list<...>, !torch.device, | ||
// !torch.string, !torch.none, !torch.generator. | ||
return rewriter.notifyMatchFailure( | ||
op, | ||
"unimplemented: inputs type. The input has to be int/float "); | ||
} | ||
} | ||
// Create output type for tosa::CustomOp input | ||
auto outType = this->getTypeConverter()->convertType(op.getType()); | ||
|
||
// Create operands for tosa::CustomOp | ||
llvm::ArrayRef<mlir::Value> ref(inputs_vec.data(), inputs_vec.size()); | ||
ValueRange custom_inputs(ref); | ||
rewriter.replaceOpWithNewOp<tosa::CustomOp>( | ||
op, outType, op.getOperationName(), custom_inputs); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
// ----------------------------------------------------------------------------- | ||
// TorchToTosaCustom Pass | ||
// ----------------------------------------------------------------------------- | ||
|
||
namespace { | ||
class ConvertTorchToTosaCustom | ||
: public ConvertTorchToTosaCustomBase<ConvertTorchToTosaCustom> { | ||
public: | ||
ConvertTorchToTosaCustom() = default; | ||
ConvertTorchToTosaCustom(ArrayRef<std::string> customOps) { | ||
this->customOps = customOps; | ||
} | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<tosa::TosaDialect>(); | ||
registry.insert<tensor::TensorDialect>(); | ||
registry.insert<arith::ArithDialect>(); | ||
TorchConversion::getBackendTypeConversionDependentDialects(registry); | ||
} | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect, | ||
arith::ArithDialect>(); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
TorchConversion::setupBackendTypeConversion(target, typeConverter); | ||
|
||
RewritePatternSet patterns(context); | ||
|
||
std::unordered_map<std::string, bool> customOpsMap; | ||
for (auto key : customOps) { | ||
customOpsMap[key] = true; | ||
} | ||
|
||
#define INSERT_ATENOP_PATTERN(AtenOp) \ | ||
target.addIllegalOp<AtenOp>(); \ | ||
patterns.add<ConvertSelectiveAtenOpToTosaCustom<AtenOp>>(typeConverter, \ | ||
context); | ||
if (customOpsMap["torch.aten.softmax.int"]) { | ||
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp); | ||
} | ||
if (customOpsMap["torch.aten.rsub.Scalar"]) { | ||
INSERT_ATENOP_PATTERN(AtenRsubScalarOp); | ||
} | ||
#undef INSERT_ATENOP_PATTERN | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
mlir::torch::createConvertTorchToTosaCustomPass( | ||
ArrayRef<std::string> customOps) { | ||
return std::make_unique<ConvertTorchToTosaCustom>(customOps); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int,torch.aten.rsub.Scalar" -split-input-file -verify-diagnostics | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim( | ||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { | ||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> | ||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 | ||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 | ||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> | ||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> | ||
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten.softmax.int"} : (tensor<2x3xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2x3xf32> | ||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> | ||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[2,3],f32> | ||
// CHECK: } | ||
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { | ||
%dtype = torch.constant.int 1 | ||
%dim = torch.constant.int 1 | ||
%ret = torch.aten.softmax.int %t, %dim, %dtype : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> | ||
return %ret : !torch.vtensor<[2,3],f32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( | ||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { | ||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> | ||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 | ||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 | ||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3> : tensor<1xi64>} : () -> tensor<1xi64> | ||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> | ||
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten.rsub.Scalar"} : (tensor<?x?xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<?x?xf32> | ||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> | ||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> | ||
// CHECK: } | ||
func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { | ||
%other = torch.constant.float 3.123400e+00 | ||
%alpha = torch.constant.int 1 | ||
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32> | ||
return %0 : !torch.vtensor<[?,?],f32> | ||
} |