From 7e694131a15cd518d8fa18f902cd0a447ee46d88 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 20 Mar 2023 16:09:37 -0500 Subject: [PATCH 1/2] support custom ops using `torch.operator` --- examples/custom_op_demo.py | 95 +++++++++++++++++++ .../Transforms/AdjustCallingConventions.cpp | 5 +- .../Torch/Transforms/ReduceOpVariants.cpp | 19 +++- .../ReifyAbstractInterpCalculationsUtils.cpp | 25 ++++- 4 files changed, 137 insertions(+), 7 deletions(-) create mode 100644 examples/custom_op_demo.py diff --git a/examples/custom_op_demo.py b/examples/custom_op_demo.py new file mode 100644 index 000000000000..d9c3c8289505 --- /dev/null +++ b/examples/custom_op_demo.py @@ -0,0 +1,95 @@ +import torch +import torch.utils.cpp_extension +import torch_mlir +from torch_mlir import run_pipeline_with_repro_report +from torch_mlir.ir import BoolAttr, Context, Module, InsertionPoint, Location +from torch_mlir_e2e_test.annotations import export, annotate_args + + +def identity(_5: torch.Tensor): + return _5 + + +goofy_lib = torch.library.Library("goofy", "DEF") +goofy_lib.define("identity(Tensor t) -> Tensor") +goofy_lib.impl("identity", identity) + + +class CustomOpExampleModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + b = 2 * a + return torch.ops.goofy.identity(b) + + +mod = CustomOpExampleModule() +mod.eval() + +module = torch_mlir.compile(mod, torch.ones(3, 4), output_type="raw") + +pipeline = ( + "symbol-dce," + "torch-prepare-for-globalize-object-graph," + "torch-globalize-object-graph," + "symbol-dce," + "inline{default-pipeline= max-iterations=4 }," + "torch-adjust-calling-conventions" +) + +run_pipeline_with_repro_report( + module, pipeline=f"builtin.module({pipeline})", description="" +) +print(module) + +forward = module.operation.regions[0].blocks[0].operations[1] +goofy_op = forward.operation.regions[0].blocks[0].operations[4] +goofy_op.attributes["has_value_semantics"] = BoolAttr.get(True, context=module.context) + +print(module) + +abstract_interp_src = """\ +func.func @__torch_mlir_shape_fn.operator.goofy.identity(%arg0: !torch.list) -> !torch.list { + return %arg0 : !torch.list +} +func.func @__torch_mlir_dtype_fn.operator.goofy.identity(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int { + return %arg1 : !torch.int +} +""" + +with Location.unknown(module.context) as loc: + abstract_interp_module = Module.parse(abstract_interp_src) + with InsertionPoint.at_block_begin(module.body) as ip: + shape_fn = abstract_interp_module.body.operations[0] + dtype_fn = abstract_interp_module.body.operations[1] + InsertionPoint.insert(ip, shape_fn.detach_from_parent()) + InsertionPoint.insert(ip, dtype_fn.detach_from_parent()) + +print(module) + +run_pipeline_with_repro_report( + module, + pipeline="builtin.module(func.func(torch-reduce-op-variants,torch-maximize-value-semantics))", + description="", +) + +print(module) + +run_pipeline_with_repro_report( + module, + pipeline="builtin.module(torch-lower-to-backend-contract{backend-legal-ops=torch.operator decompose=true max-iterations=10})", + description="", +) + +shape_fn.detach_from_parent() +dtype_fn.detach_from_parent() + +print(module) diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 7cc699c04762..d4ac36a1c4e8 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -191,7 +191,10 @@ static bool isValidNonContainerResultType(Type resultType) { resultType.isa() || resultType.isa() || resultType.isa() || - resultType.isa(); + resultType.isa() || + (resultType.isa() && cast(resultType) + .getContainedType() + .isa()); } static LogicalResult validateReturns(func::FuncOp func) { diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 78bc0703e46c..73a9f2e5ccd1 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -52,6 +52,14 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) { } } +static bool operatorOpHasValueSemantics(OperatorOp opOp) { + if (!opOp->hasAttr("has_value_semantics")) + return false; + auto hasValueSemantics = + opOp->getAttr("has_value_semantics").cast().getValue(); + return hasValueSemantics; +} + namespace { // Convert value semantic ops operating on mutable arrays to instead operate on // immutable tensors. @@ -61,8 +69,13 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!op->hasTrait()) + if (isa(op)) { + if (!operatorOpHasValueSemantics(cast(op))) { + return rewriter.notifyMatchFailure(op, "does not have value semantics"); + } + } else if (!op->hasTrait()) { return rewriter.notifyMatchFailure(op, "does not have value semantics"); + } rewriter.startRootUpdate(op); // Convert all operands. @@ -254,7 +267,9 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { target.addIllegalOp(); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { - if (op->hasTrait()) { + if (op->hasTrait() || + (isa(op) && + operatorOpHasValueSemantics(cast(op)))) { auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system. diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 04bbe0220e71..87085f9b8ffa 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -66,7 +66,16 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( libFuncArgsBuilder) { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); - auto name = op->getName().stripDialect(); + + std::string name_; + if (isa(op)) { + auto opOp = cast(op); + auto opName = opOp->getAttr("name").cast().getValue(); + name_ = "operator." + opName.str(); + } else { + name_ = op->getName().stripDialect(); + } + StringRef name = name_; // For value-semantic variant ops, i.e. valsem-ops (ops that are // mechanically consistent with existing torch conventions of in-place vs. // out-of-place (value-semantic) variants), remove the prefix when @@ -76,9 +85,17 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( std::string libFuncName = (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str(); auto libFunc = library.lookupSymbol(libFuncName); - if (!libFunc) - return success(); - libFuncNamesUsed.push_back(libFuncName); + if (!libFunc) { + auto parentModule = op->getParentOfType(); + if (parentModule) + libFunc = + op->getParentOfType().lookupSymbol(libFuncName); + if (!libFunc) + return success(); + } else { + libFuncNamesUsed.push_back(libFuncName); + } + OpBuilder b(op); Operation *calculate = createCalculateOp(b, loc, op->getResultTypes(), libFuncKind); From ae3556738fa15de02d60bfea34cff134a8a0fd16 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 21 Mar 2023 08:26:56 -0500 Subject: [PATCH 2/2] add `HasValueSemantics` to `torch.operator` instead of special-casing in `ReduceOpVariants` --- .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 3 ++- .../Torch/Transforms/ReduceOpVariants.cpp | 19 ++----------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 889a29908de0..4d1176e4a469 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -829,7 +829,8 @@ def Torch_DerefineOp : Torch_Op<"derefine", [ } def Torch_OperatorOp : Torch_Op<"operator", [ - AllowsTypeRefinement + AllowsTypeRefinement, + HasValueSemantics ]> { let summary = "Opaque torch operator"; let description = [{ diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 73a9f2e5ccd1..78bc0703e46c 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -52,14 +52,6 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) { } } -static bool operatorOpHasValueSemantics(OperatorOp opOp) { - if (!opOp->hasAttr("has_value_semantics")) - return false; - auto hasValueSemantics = - opOp->getAttr("has_value_semantics").cast().getValue(); - return hasValueSemantics; -} - namespace { // Convert value semantic ops operating on mutable arrays to instead operate on // immutable tensors. @@ -69,13 +61,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (isa(op)) { - if (!operatorOpHasValueSemantics(cast(op))) { - return rewriter.notifyMatchFailure(op, "does not have value semantics"); - } - } else if (!op->hasTrait()) { + if (!op->hasTrait()) return rewriter.notifyMatchFailure(op, "does not have value semantics"); - } rewriter.startRootUpdate(op); // Convert all operands. @@ -267,9 +254,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { target.addIllegalOp(); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { - if (op->hasTrait() || - (isa(op) && - operatorOpHasValueSemantics(cast(op)))) { + if (op->hasTrait()) { auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system.