Skip to content

Commit

Permalink
Re-enable custom op support
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Aug 2, 2022
1 parent 82af44d commit 0c76800
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 37 deletions.
19 changes: 8 additions & 11 deletions build_tools/update_shape_lib.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@ torch_transforms_cpp_dir="${src_dir}/lib/Dialect/Torch/Transforms"
python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"

pypath="${python_packages_dir}/torch_mlir"
# TODO: Re-enable once custom op support is back.
#if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
# pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
#fi
#ext_module="torch_mlir._torch_mlir_custom_op_example"
#if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
# ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES} "
#fi
if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
fi
ext_module="torch_mlir._torch_mlir_custom_op_example"
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES} "
fi

PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
--pytorch_op_extensions=${ext_module} \
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"

# TODO: Add back to shape_lib_gen invocation once custom op support is back.
# --pytorch_op_extensions=${ext_module} \
19 changes: 8 additions & 11 deletions build_tools/update_torch_ods.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,16 @@ torch_ir_include_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"

pypath="${python_packages_dir}/torch_mlir"
# TODO: Re-enable once custom op support is back.
#if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
# pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
#fi
#ext_module="torch_mlir._torch_mlir_custom_op_example"
#if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
# ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES}"
#fi
if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
fi
ext_module="torch_mlir._torch_mlir_custom_op_example"
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES}"
fi

PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
--torch_ir_include_dir="${torch_ir_include_dir}" \
--pytorch_op_extensions="${ext_module}" \
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"

# TODO: Add back to torch_ods_gen invocation once custom op support is back.
# --pytorch_op_extensions="${ext_module}" \
22 changes: 22 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9292,3 +9292,25 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
}];
}

// Example custom op `_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)`
// used to exercise the out-of-tree custom-op extension flow. It takes a single
// tensor and returns it unchanged (ReadOnly + HasValueSemantics).
def Torch_TorchMlirCustomOpExampleIdentityOp : Torch_Op<"_torch_mlir_custom_op_example.identity", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$t
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Parse/print via the shared default Torch-op assembly helpers
  // (1 operand, 1 result).
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult TorchMlirCustomOpExampleIdentityOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void TorchMlirCustomOpExampleIdentityOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToLinalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchToLinalg
# TODO: Re-enable after MacOS support is fixed for the custom op extension.
# CustomOpExample.cpp
CustomOpExample.cpp
DataMovement.cpp
IndirectDataMovement.cpp
Linear.cpp
Expand Down
54 changes: 54 additions & 0 deletions lib/Conversion/TorchToLinalg/CustomOpExample.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
/// Lowers the example custom op `_torch_mlir_custom_op_example.identity`
/// to the linalg-on-tensors backend contract. Since the op is the identity
/// function, the lowering simply forwards the operand to the result's uses.
class ConvertCustomOpExample
    : public OpConversionPattern<TorchMlirCustomOpExampleIdentityOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TorchMlirCustomOpExampleIdentityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out on operand/result types linalg cannot represent.
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    // The op does nothing, so replace its result with its argument and erase
    // the op. Use the adaptor's operands rather than op->getOperands(): in a
    // conversion pattern the adaptor carries the already type-converted
    // (builtin tensor) values, which is what downstream converted uses expect.
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};
} // namespace

/// Registers the lowering for the example custom op and marks the op illegal
/// so the conversion driver is forced to apply the pattern.
void mlir::torch::torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
    TypeConverter &typeConverter,
    RewritePatternSet &patterns,
    ConversionTarget &target) {
  // Once lowering runs, no instance of the example op may remain.
  target.addIllegalOp<TorchMlirCustomOpExampleIdentityOp>();
  patterns.add<ConvertCustomOpExample>(typeConverter, patterns.getContext());
}
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToLinalg/PopulatePatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ void populateIndirectDataMovementPatternsAndLegality(
void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
//void populateCustomOpExamplePatternsAndLegality(TypeConverter &typeConverter,
// RewritePatternSet &patterns,
// ConversionTarget &target);
void populateCustomOpExamplePatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

} // namespace torch_to_linalg
} // namespace torch
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class ConvertTorchToLinalg

RewritePatternSet patterns(context);

//torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
// typeConverter, patterns, target);
torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
typeConverter, patterns, target);
torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
typeConverter, patterns, target);
torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns,
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp,
TorchMlirCustomOpExampleIdentityOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6853,6 +6853,10 @@ module {
%3 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %1, %arg3, %2) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %3 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn._torch_mlir_custom_op_example.identity"(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
}
)mlir");
#pragma clang diagnostic pop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1156,9 +1156,8 @@ def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)

# TODO: Re-enable after MacOS support is fixed for the extension.
#def _torch_mlir_custom_op_example〇identity(t: List[int]) -> List[int]:
# return upstream_shape_functions.unary(t)
# Shape transfer function for the example custom op: identity preserves the
# input shape exactly, so delegate to the upstream unary (same-shape) helper.
# NOTE: this function is scripted by the shape-library generator, so the body
# must follow the upstream_shape_functions calling convention.
def _torch_mlir_custom_op_example〇identity(t: List[int]) -> List[int]:
    return upstream_shape_functions.unary(t)

# ==============================================================================
# Shape library generator main().
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,7 @@ def emit_with_mutating_variants(key, **kwargs):
# extension.
# ==========================================================================

# TODO: Re-enable after MacOS support is fixed for the extension.
#emit("_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)")
emit("_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)")


def dump_registered_ops(outfile: TextIO, registry: Registry):
Expand Down
3 changes: 1 addition & 2 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,4 @@ def register_all_tests():
from . import return_types
from . import control_flow
from . import stats
# TODO: Re-enable after MacOS support is fixed for the extension.
#from . import custom_op_example
from . import custom_op_example

0 comments on commit 0c76800

Please sign in to comment.