-
Notifications
You must be signed in to change notification settings - Fork 516
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support to lower backend-legal operations as TOSA CustomOps
- Loading branch information
Showing
8 changed files
with
303 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
201 changes: 201 additions & 0 deletions
201
lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" | ||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/Dialect/Traits.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" | ||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" | ||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::torch; | ||
using namespace mlir::torch::Torch; | ||
|
||
namespace { | ||
|
||
class ConvertBackendLegalAtenOp : public ConversionPattern { | ||
public: | ||
ArrayRef<std::string> customOps; | ||
|
||
ConvertBackendLegalAtenOp(TypeConverter &typeConverter, MLIRContext *context, | ||
ArrayRef<std::string> customOps) | ||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) { | ||
this->customOps = customOps; | ||
} | ||
|
||
Value matchCustomOpOperandValue(Value inputOperand, PatternRewriter &rewriter, | ||
Operation *backendLegalOp) const { | ||
// Get the Torch Op to find the constant attributes attached to the | ||
// backendLegalOp | ||
Value torchInputOperand = inputOperand; | ||
if (auto unrealizedCastOp = dyn_cast_or_null<UnrealizedConversionCastOp>( | ||
inputOperand.getDefiningOp())) { | ||
torchInputOperand = unrealizedCastOp.getInputs()[0]; | ||
} | ||
// Handle the special case where input operand is an argument to the module | ||
// function | ||
if (!torchInputOperand.getDefiningOp()) | ||
return inputOperand; | ||
|
||
return TypeSwitch<Operation *, Value>(torchInputOperand.getDefiningOp()) | ||
.Case<Torch::ConstantBoolOp>([&](Operation *boolOperand) -> Value { | ||
bool boolConstAttr; | ||
if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) { | ||
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp, | ||
boolConstAttr, {}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::ConstantIntOp>([&](Operation *intOperand) -> Value { | ||
int64_t intConstAttr; | ||
if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) { | ||
return tosa::getConstTensor<int64_t>(rewriter, backendLegalOp, | ||
intConstAttr, {}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::ConstantFloatOp>([&](Operation *floatOperand) -> Value { | ||
double floatConstAttr; | ||
if (matchPattern(floatOperand, | ||
m_TorchConstantFloat(&floatConstAttr))) { | ||
return tosa::getConstTensor<float>(rewriter, backendLegalOp, | ||
floatConstAttr, {}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::PrimListConstructOp>( | ||
[&](Operation *intListConstructOperand) -> Value { | ||
SmallVector<int64_t> intConstListAttr; | ||
if (matchPattern(intListConstructOperand, | ||
m_TorchConstantIntList(intConstListAttr))) { | ||
return tosa::getConstTensor<int64_t>( | ||
rewriter, backendLegalOp, intConstListAttr, | ||
{static_cast<int64_t>(intConstListAttr.size())}) | ||
.value(); | ||
} | ||
return nullptr; | ||
}) | ||
.Case<Torch::ConstantNoneOp>([&](Operation *noneOperand) -> Value { | ||
auto noneCustomOp = rewriter.create<tosa::CustomOp>( | ||
backendLegalOp->getLoc(), | ||
RankedTensorType::get({}, rewriter.getIntegerType(1)), | ||
rewriter.getStringAttr("torch.none"), ValueRange{}); | ||
return noneCustomOp.getResult(0); | ||
}) | ||
.Default([&](Operation *defaultOperand) { return inputOperand; }); | ||
} | ||
|
||
LogicalResult | ||
matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||
ConversionPatternRewriter &rewriter) const final { | ||
|
||
for (auto customOpName : customOps) { | ||
if (customOpName == op->getName().getStringRef()) { | ||
SmallVector<Value, 4> customOpInputOperands; | ||
|
||
for (auto operand : operands) { | ||
Value customOpInputOperand = | ||
matchCustomOpOperandValue(operand, rewriter, op); | ||
if (!customOpInputOperand) { | ||
return rewriter.notifyMatchFailure( | ||
op, | ||
"failed to match the constant operand of the backend-legal Op"); | ||
} | ||
customOpInputOperands.push_back(customOpInputOperand); | ||
} | ||
SmallVector<Type> customOpResultTypes; | ||
auto convertTypesResult = getTypeConverter()->convertTypes( | ||
op->getResultTypes(), customOpResultTypes); | ||
if (convertTypesResult.failed()) | ||
return failure(); | ||
auto tosaCustomOp = rewriter.create<tosa::CustomOp>( | ||
op->getLoc(), TypeRange{customOpResultTypes}, | ||
llvm::StringRef(op->getName().stripDialect()), | ||
ValueRange{customOpInputOperands}); | ||
// TODO: Use the 'config' StringAttr in the TOSA CustomOp interface | ||
// instead. This requires the LLVM version uplift to include new TOSA | ||
// CustomOp definition | ||
tosaCustomOp.getOperation()->setAttr( | ||
rewriter.getStringAttr("config"), | ||
rewriter.getStringAttr("torch_mlir")); | ||
rewriter.replaceOp(op, tosaCustomOp.getResults()); | ||
return success(); | ||
} | ||
} | ||
return failure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
// ----------------------------------------------------------------------------- | ||
// TorchBackendLegalToTosaCustom Pass | ||
// ----------------------------------------------------------------------------- | ||
|
||
namespace { | ||
class ConvertTorchBackendLegalToTosaCustom | ||
: public ConvertTorchBackendLegalToTosaCustomBase< | ||
ConvertTorchBackendLegalToTosaCustom> { | ||
public: | ||
ConvertTorchBackendLegalToTosaCustom() = default; | ||
ConvertTorchBackendLegalToTosaCustom(ArrayRef<std::string> customOps) { | ||
this->customOps = customOps; | ||
} | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<tosa::TosaDialect>(); | ||
registry.insert<tensor::TensorDialect>(); | ||
registry.insert<arith::ArithDialect>(); | ||
TorchConversion::getBackendTypeConversionDependentDialects(registry); | ||
} | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect, | ||
arith::ArithDialect>(); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
TorchConversion::setupBackendTypeConversion(target, typeConverter); | ||
|
||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<ConvertBackendLegalAtenOp>(typeConverter, context, customOps); | ||
|
||
for (std::string opName : customOps) { | ||
target.addIllegalOp(OperationName(opName, context)); | ||
} | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
mlir::torch::createConvertTorchBackendLegalToTosaCustomPass( | ||
ArrayRef<std::string> customOps) { | ||
return std::make_unique<ConvertTorchBackendLegalToTosaCustom>(customOps); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// RUN: torch-mlir-opt <%s -convert-torch-backend-legal-to-tosa-custom="custom-ops=torch.aten._softmax" -split-input-file -verify-diagnostics | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func @torch.aten._softmax( | ||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { | ||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,4],f32> -> tensor<4x4xf32> | ||
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1 | ||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false | ||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64> | ||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> | ||
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten._softmax"} : (tensor<4x4xf32>, tensor<i64>, tensor<i64>) -> tensor<4x4xf32> | ||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x4xf32> -> !torch.vtensor<[4,4],f32> | ||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,4],f32> | ||
// CHECK: } | ||
func.func @torch.aten._softmax(%arg0: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { | ||
%int-1 = torch.constant.int -1 | ||
%false = torch.constant.bool false | ||
%0 = torch.aten._softmax %arg0, %int-1, %false : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[4,4],f32> | ||
return %0 : !torch.vtensor<[4,4],f32> | ||
} |
bc3e7a5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {config = "torch_mlir", identifier = "aten._softmax"} : (tensor<4x4xf32>, tensor<i64>, tensor<i64>) -> tensor<4x4xf32>