Skip to content

Commit

Permalink
[MLIR][TORCH] Add Pass for custom softmax ops
Browse files Browse the repository at this point in the history
1. Add an option "customOps" for convert-torch-to-tosa
2. Enable multi input arg for tosa::custom instead use multi attributes
  • Loading branch information
AmosLewis authored and AmosLewis committed Oct 27, 2022
1 parent 227c114 commit 8d0c781
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 39 deletions.
11 changes: 10 additions & 1 deletion include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,16 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
This pass assumes that TOSA ops are responsible for emitting error
guards in case of shape mismatches.
}];
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
let options = [
ListOption<"customOps", "custom-ops", "std::string",
"List of operation names that should be converted to tosa::custom",
"llvm::cl::ZeroOrMore">
];

let constructor = [{
mlir::torch::createConvertTorchToTosaPass(
/*customOps=*/{})
}];
}

def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
Expand Down
4 changes: 3 additions & 1 deletion include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass(
ArrayRef<std::string> customOps
);
}
} // namespace mlir

Expand Down
14 changes: 13 additions & 1 deletion include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,19 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
struct TosaBackendPipelineOptions
: public PassPipelineOptions<TosaBackendPipelineOptions> {
// If this option is true, custom complex operations.
// If this option is false, skip decomposition of complex operations.
Option<bool> custom{*this, "custom-ops",
llvm::cl::desc("custom complex operations."),
llvm::cl::init(true)};
ListOption<std::string> customOps{
*this, "custom-ops",
llvm::cl::desc("List of ops to be converted to the backend."),
llvm::cl::ZeroOrMore};
};
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm, const TosaBackendPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
Expand Down
32 changes: 16 additions & 16 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3042,7 +3042,7 @@ LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite(
// %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
// %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
// No-Decompose TOSA format: without -torch-decompose-complex-ops flag
// "tosa.custom(%x){dim = 1 : i64, identifier = "softmax"}" : (tensor<2x3xf32>) -> tensor<2x3xf32>
// "tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>

// Check AtenSoftmaxIntOp first input is a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
Expand All @@ -3059,21 +3059,15 @@ LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite(

// Create output type for tosa::CustomOp input
auto outType = getTypeConverter()->convertType(op.getType());

// Create attributes for tosa::CustomOp input
// example: {dim = 1 : i64, identifier = "softmax"}
StringAttr nameIdAttr= rewriter.getStringAttr("identifier");
// Create name attribute and multi-args for tosa::CustomOp input
StringAttr nameValueAttr= rewriter.getStringAttr("softmax");
StringAttr dimIdAttr= rewriter.getStringAttr("dim");
IntegerAttr dimValueAttr = rewriter.getI64IntegerAttr(dim);
mlir::NamedAttribute nameAttr = mlir::NamedAttribute(nameIdAttr, nameValueAttr);
mlir::NamedAttribute dimAttr = mlir::NamedAttribute(dimIdAttr, dimValueAttr);
llvm::ArrayRef<mlir::NamedAttribute> custom_attributes{nameAttr, dimAttr};
auto dimTensor = tosa::getConstTensor<int64_t>(
rewriter, op.getOperation(), dim, {1});
SmallVector<Value,2> inputOperands{adaptor.self(), dimTensor.value()};

// TODO unportable target hardware implementation of exp(%x) / sum(exp(%x), %dim)
rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, adaptor.self(),
custom_attributes);

rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, nameValueAttr,
inputOperands);
return success();
}

Expand Down Expand Up @@ -3708,6 +3702,10 @@ class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
namespace {
class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
public:
ConvertTorchToTosa() = default;
ConvertTorchToTosa(ArrayRef<std::string> customOps) {
this->customOps = customOps;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
Expand Down Expand Up @@ -3904,7 +3902,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
if(customOps.front()=="torch.aten.softmax.int"){
INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
}
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
Expand All @@ -3925,6 +3925,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaPass() {
return std::make_unique<ConvertTorchToTosa>();
mlir::torch::createConvertTorchToTosaPass(ArrayRef<std::string> customOps) {
return std::make_unique<ConvertTorchToTosa>(customOps);
}
6 changes: 3 additions & 3 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
Expand Down Expand Up @@ -96,8 +96,8 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
}

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
OpPassManager &pm, const TosaBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass(options.customOps));
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());

Expand Down
17 changes: 0 additions & 17 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -929,21 +929,4 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> {
%0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32>
return %0 : !torch.vtensor<[1,12,5,5],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.custom"(%[[VAL_1]]) {dim = 1 : i64, identifier = "softmax"} : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32>
return %ret : !torch.vtensor<[2,3],f32>
}
19 changes: 19 additions & 0 deletions test/Conversion/TorchToTosa/custom.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa="custom-ops=torch.aten.softmax.int" -split-input-file -verify-diagnostics | FileCheck %s

// -----
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]]) {identifier = "softmax"} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32>
return %ret : !torch.vtensor<[2,3],f32>
}

0 comments on commit 8d0c781

Please sign in to comment.