Allow torch-mlir to support PyTorch extensions. (llvm#895)
PyTorch allows new operators to be registered dynamically in modules.
Torch-mlir already makes it fairly straightforward to add support for
new operators, and this commit extends that support to allow new
PyTorch ops to come from an external module.

This does *not* allow ops to be dynamically loaded into torch-mlir.
Torch-mlir must still be compiled with support built-in.

Add a `_torch_mlir_custom_op_example` subpackage to `torch_mlir` which
registers a demonstration op. It will not be imported by default when
importing torch_mlir. It's strictly for testing and documentation.

Adds an end-to-end test for the `torch_mlir_custom_op_example::identity` op.

With all these changes, we should now be actively testing PyTorch extension
support with all future patches.
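
For context, the kind of extension this commit targets is an ordinary Python module that registers its operators with the PyTorch JIT registry when imported. A minimal sketch, with hypothetical module and library names:

```python
# my_torch_ext/__init__.py -- hypothetical extension module. The generator
# scripts import every module named in TORCH_MLIR_EXT_MODULES before scanning
# the JIT operator registry, so importing this module must register the ops.
import torch

# Loading the compiled shared library runs its registration code (e.g. a
# TORCH_LIBRARY block), making ops such as `my_namespace::identity` visible
# to torch_ods_gen and shape_lib_gen.
torch.ops.load_library("libmy_custom_ops.so")
```

With the module's path on `TORCH_MLIR_EXT_PYTHONPATH` and its name in `TORCH_MLIR_EXT_MODULES`, the `build_tools/update_*.sh` scripts below are set up to pick up its ops via `--pytorch_op_extensions` (currently commented out pending re-enablement).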
rdadolf authored and JakopinA committed Jun 23, 2022
1 parent 4bf3452 commit a4f1b6c
Showing 34 changed files with 923 additions and 65 deletions.
27 changes: 24 additions & 3 deletions build_tools/update_shape_lib.sh
@@ -1,13 +1,34 @@
#!/bin/bash
# Updates auto-generated shape library files for the `torch` dialect.
set -e
#
# Environment variables:
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
# which register custom PyTorch operators upon being imported.
# TORCH_MLIR_EXT_PYTHONPATH: colon-separated list of paths necessary
# for importing PyTorch extensions specified in TORCH_MLIR_EXT_MODULES.
# For more information on supporting custom operators, see:
# ${TORCH_MLIR}/python/torch_mlir/_torch_mlir_custom_op_example/README.md

set -eo pipefail

src_dir="$(realpath $(dirname $0)/..)"
build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
torch_transforms_cpp_dir="${src_dir}/lib/Dialect/Torch/Transforms"
python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"

#ninja -C "${build_dir}"
PYTHONPATH="${python_packages_dir}/torch_mlir" python \
pypath="${python_packages_dir}/torch_mlir"
# TODO: Re-enable once custom op support is back.
#if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
# pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
#fi
#ext_module="torch_mlir._torch_mlir_custom_op_example"
#if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
# ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES} "
#fi

PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"

# TODO: Add back to shape_lib_gen invocation once custom op support is back.
# --pytorch_op_extensions=${ext_module} \
27 changes: 24 additions & 3 deletions build_tools/update_torch_ods.sh
@@ -1,14 +1,35 @@
#!/bin/bash
# Updates auto-generated ODS files for the `torch` dialect.
set -euo pipefail
#
# Environment variables:
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
# which register custom PyTorch operators upon being imported.
# TORCH_MLIR_EXT_PYTHONPATH: colon-separated list of paths necessary
# for importing PyTorch extensions specified in TORCH_MLIR_EXT_MODULES.
# For more information on supporting custom operators, see:
# ${TORCH_MLIR}/python/torch_mlir/_torch_mlir_custom_op_example/README.md

set -eo pipefail

src_dir="$(realpath $(dirname $0)/..)"
build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
torch_ir_include_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
python_packages_dir="${build_dir}/tools/torch-mlir/python_packages"

#ninja -C "${build_dir}"
PYTHONPATH="${python_packages_dir}/torch_mlir" python \
pypath="${python_packages_dir}/torch_mlir"
# TODO: Re-enable once custom op support is back.
#if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
# pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
#fi
#ext_module="torch_mlir._torch_mlir_custom_op_example"
#if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
# ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES}"
#fi

PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
--torch_ir_include_dir="${torch_ir_include_dir}" \
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"

# TODO: Add back to torch_ods_gen invocation once custom op support is back.
# --pytorch_op_extensions="${ext_module}" \
10 changes: 7 additions & 3 deletions docs/adding_a_shape_function.md
@@ -26,9 +26,13 @@ We will use the example of adding support for the `torch.aten.tanh` op.
functions don't get outdated if Torch changes an operator signature.

3. Fill in the body of the shape function. Ideally this will just be a call into
a helper function from `upstream_shape_helpers.py`. But in general, you will
need to write the shape function and test it (see the comments about "Shape
function testing infrastructure" in `shape_lib_gen.py`).
a helper function from
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
But in general, you will need to write the shape function and test it (see
the comments about "Shape function testing infrastructure" in
`shape_lib_gen.py`). New shape functions should be added upstream following
the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
though it can be useful to iterate locally in `shape_lib_gen.py` first.

4. Re-run the `build_tools/update_shape_lib.sh` script to update the shape
library. After this step happens, ideally everything "just works" and the
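To make step 3 concrete: for the `torch.aten.tanh` example this doc follows, the shape function reduces to a one-line call into the upstream helpers. A sketch using torch-mlir's `〇` naming convention, assuming the upstream `unary` helper:

```python
from typing import List

# shape_lib_gen.py conventionally imports the upstream helpers under this alias.
import torch.jit._shape_functions as upstream_shape_functions

# Shape function for `aten::tanh`: `〇` stands in for `::` in the op name, and
# each `Tensor` argument becomes a `List[int]` of dimension sizes.
def aten〇tanh(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
```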
11 changes: 2 additions & 9 deletions docs/shape_lib.md
@@ -26,8 +26,8 @@ Shape functions are defined as TorchScript-able Python functions in
The signatures of the shape functions are systematically derived from Torch JIT
operator registry (mainly by replacing `Tensor` with `List[int]` in the operator
signatures). Most shape functions are expected to reuse the upstream helper
functions in
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py`.
functions in [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
and any new shape functions should be added there.

The `build_tools/update_shape_lib.sh` script invokes `shape_lib_gen.py` to
generate an MLIR module containing the shape functions, which is currently
@@ -119,10 +119,3 @@ was based on the following goals:
written, which are still a fairly large and non-trivial set.

- To make it as mechanical as possible to add a new shape function.

## TODO

We should develop a workflow with upstream to push our manually-authored shape
functions to live and be tested there. We should also find a way to share with
upstream the mapping between operators and their shape functions. We will be
able to simplify this infrastructure quite a bit once that happens.
8 changes: 6 additions & 2 deletions e2e_testing/torchscript/xfail_sets.py
@@ -63,8 +63,8 @@
"MmDagModule_basic",
"Matmul_dot",
"Matmul_3d",
"RsubModule_basic",
"RsubModule_noalpha_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",
@@ -160,4 +160,8 @@
"BaddbmmWithBetaModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
"NumpyTRankNDynamicModule_basic",
}
94 changes: 94 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2637,6 +2637,30 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
}];
}

def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::floor_divide : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFloorDivideOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenFloorDivideOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
AllowsTypeRefinement
]> {
@@ -4125,6 +4149,29 @@ def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [
}];
}

def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::is_floating_point : (Tensor) -> (bool)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIsFloatingPointOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIsFloatingPointOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -5955,6 +6002,28 @@ def Torch_AtenTOp : Torch_Op<"aten.t", [
}];
}

def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenNumpyTOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenFullOp : Torch_Op<"aten.full", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -6090,6 +6159,31 @@ def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
let hasFolder = 1;
}

def Torch_Aten__Contains__IntListOp : Torch_Op<"aten.__contains__.int_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__contains__.int_list : (int[], int) -> (bool)`";
let arguments = (ins
AnyTorchListOfTorchIntType:$l,
Torch_IntType:$item
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__Contains__IntListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__Contains__IntListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_Aten__Getitem__DictStrOp : Torch_Op<"aten.__getitem__.Dict_str", [
AllowsTypeRefinement,
ReadOnly
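These `.td` definitions are generated rather than hand-written: `torch_ods_gen.py` emits each one from the JIT registry signature quoted in the op's `summary`. The new ops above would correspond to emitter calls roughly like the following sketch (excerpt-style, following torch-mlir's `emit()` convention, not verbatim from this commit):

```python
# Inside torch_ods_gen.py's op listing: each generated op is driven by an
# emit() call keyed on the registry signature; traits such as
# HasValueSemantics/ReadOnly are derived from the registry record itself.
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
emit("aten::is_floating_point : (Tensor) -> (bool)")
emit("aten::numpy_T : (Tensor) -> (Tensor)")
emit("aten::__contains__.int_list : (int[], int) -> (bool)", has_folder=True)
```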
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToLinalg/CMakeLists.txt
@@ -1,4 +1,6 @@
add_mlir_conversion_library(TorchMLIRTorchToLinalg
# TODO: Re-enable after MacOS support is fixed for the custom op extension.
# CustomOpExample.cpp
DataMovement.cpp
IndirectDataMovement.cpp
Linear.cpp
9 changes: 6 additions & 3 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -696,9 +696,12 @@ class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
for (unsigned i = 0; i < inputRank; i++)
swapExprs.push_back(idExprs[dimensions[i]]);

SmallVector<AffineMap> indexingMaps =
AffineMap::inferFromExprList({idExprs, swapExprs});
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
AffineMap inputMap = AffineMap::get(inputRank, /*symbolCount=*/0, idExprs,
op->getContext());
AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs,
op->getContext());
SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
SmallVector<StringRef> iteratorTypes(inputRank, getParallelIteratorTypeName());
auto transpose = rewriter
.create<linalg::GenericOp>(
loc, outVector.getType(), inVector, outVector,
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -133,7 +133,7 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
sizes.push_back(embeddingDim);
int64_t resultRank = sizes.size();

auto indicesTy = weight.getType().cast<RankedTensorType>();
auto indicesTy = indices.getType().cast<RankedTensorType>();
int64_t indicesRank = indicesTy.getRank();
SmallVector<AffineExpr> indicesExprs;
for (int i = 0; i < indicesRank; i++)
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToLinalg/PopulatePatterns.h
@@ -63,6 +63,9 @@ void populateIndirectDataMovementPatternsAndLegality(
void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
//void populateCustomOpExamplePatternsAndLegality(TypeConverter &typeConverter,
// RewritePatternSet &patterns,
// ConversionTarget &target);

} // namespace torch_to_linalg
} // namespace torch
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -62,6 +62,8 @@ class ConvertTorchToLinalg

RewritePatternSet patterns(context);

//torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
// typeConverter, patterns, target);
torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
typeConverter, patterns, target);
torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns,
18 changes: 11 additions & 7 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -736,15 +736,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(rsub.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
rsub.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value self = payloadArgs[0];
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
if (dtype.isa<mlir::FloatType>()) {
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
Value mult = b.create<arith::MulIOp>(loc, self, alpha);
return b.create<arith::SubIOp>(loc, other, mult);
}
rsub.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
Type dtype = converter->convertType(mulScalar.getType())
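The rewritten lowering implements `aten::rsub`'s semantics, `result = other - self * alpha`, for integer element types as well as the previously supported floats. A quick PyTorch-level check of the behavior being matched:

```python
import torch

a = torch.tensor([1, 2, 3])            # integer dtype: the newly supported path
print(torch.rsub(a, 10, alpha=2))      # 10 - 2*a -> tensor([8, 6, 4])

b = torch.tensor([1.0, 2.0])           # float dtype: the pre-existing path
print(torch.rsub(b, 10.0, alpha=2.0))  # tensor([8., 6.])
```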
20 changes: 20 additions & 0 deletions lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -50,6 +50,24 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
};
} // namespace

namespace {
class ConvertAtenIsFloatingPointOp
: public OpConversionPattern<AtenIsFloatingPointOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIsFloatingPointOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tensorType = op.self().getType().cast<BaseTensorType>();
bool result =
tensorType.hasDtype() && tensorType.getDtype().isa<mlir::FloatType>();
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, BoolAttr::get(getContext(), result));
return success();
}
};
} // namespace

namespace {
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
public:
@@ -301,6 +319,8 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
RewritePatternSet patterns(context);
target.addIllegalOp<AtenDimOp>();
patterns.add<ConvertAtenDimOp>(typeConverter, context);
target.addIllegalOp<AtenIsFloatingPointOp>();
patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp>();
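Because type refinement makes a value-semantic tensor's dtype statically known, `ConvertAtenIsFloatingPointOp` can fold the query to a compile-time `arith.constant`. The PyTorch-level behavior it mirrors:

```python
import torch

print(torch.tensor([1.0]).is_floating_point())  # True  -> folds to constant true
print(torch.tensor([1]).is_floating_point())    # False -> folds to constant false
```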