support custom ops using torch.operator
makslevental committed Mar 20, 2023
1 parent c2ef5f4 commit cc5e2c2
Showing 4 changed files with 139 additions and 7 deletions.
95 changes: 95 additions & 0 deletions examples/custom_op_demo.py
@@ -0,0 +1,95 @@
import torch
import torch.utils.cpp_extension
import torch_mlir
from torch_mlir import run_pipeline_with_repro_report
from torch_mlir.ir import BoolAttr, Context, Module, InsertionPoint, Location
from torch_mlir_e2e_test.annotations import export, annotate_args


def identity(_5: torch.Tensor):
    return _5


# Register a custom op, goofy::identity, with PyTorch.
goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)


class CustomOpExampleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        b = 2 * a
        return torch.ops.goofy.identity(b)


mod = CustomOpExampleModule()
mod.eval()

# Import to raw torch dialect IR; the custom op is imported as a
# torch.operator op.
module = torch_mlir.compile(mod, torch.ones(3, 4), output_type="raw")

pipeline = (
    "symbol-dce,"
    "torch-prepare-for-globalize-object-graph,"
    "torch-globalize-object-graph,"
    "symbol-dce,"
    "inline{default-pipeline= max-iterations=4 },"
    "torch-adjust-calling-conventions"
)

run_pipeline_with_repro_report(
    module, pipeline=f"builtin.module({pipeline})", description=""
)
print(module)

# Mark the torch.operator op as having value semantics so that the
# ReduceOpVariants and MaximizeValueSemantics passes treat it like a
# value-semantic op.
forward = module.operation.regions[0].blocks[0].operations[1]
goofy_op = forward.operation.regions[0].blocks[0].operations[4]
goofy_op.attributes["has_value_semantics"] = BoolAttr.get(True, context=module.context)

print(module)

# Shape and dtype transfer functions for the custom op, named so the
# abstract interpretation passes can look them up.
abstract_interp_src = """\
func.func @__torch_mlir_shape_fn.operator.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
  return %arg0 : !torch.list<int>
}
func.func @__torch_mlir_dtype_fn.operator.goofy.identity(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  return %arg1 : !torch.int
}
"""

# Splice the transfer functions into the module itself.
with Location.unknown(module.context) as loc:
    abstract_interp_module = Module.parse(abstract_interp_src)
    with InsertionPoint.at_block_begin(module.body) as ip:
        shape_fn = abstract_interp_module.body.operations[0]
        dtype_fn = abstract_interp_module.body.operations[1]
        InsertionPoint.insert(ip, shape_fn.detach_from_parent())
        InsertionPoint.insert(ip, dtype_fn.detach_from_parent())

print(module)

run_pipeline_with_repro_report(
    module,
    pipeline="builtin.module(func.func(torch-reduce-op-variants,torch-maximize-value-semantics))",
    description="",
)

print(module)

# Lower to the backend contract, keeping torch.operator legal.
run_pipeline_with_repro_report(
    module,
    pipeline="builtin.module(torch-lower-to-backend-contract{backend-legal-ops=torch.operator decompose=true max-iterations=10})",
    description="",
)

# Remove the transfer functions now that refinement has run.
shape_fn.detach_from_parent()
dtype_fn.detach_from_parent()

print(module)
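
A quick eager-mode sanity check of the registration (a sketch, not part of this commit; it assumes the torch.library registration above has already run in the same process):

x = torch.ones(3, 4)
# Dispatches to the Python `identity` impl registered on the goofy library.
assert torch.allclose(torch.ops.goofy.identity(x), x)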
5 changes: 4 additions & 1 deletion lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -191,7 +191,10 @@ static bool isValidNonContainerResultType(Type resultType) {
          resultType.isa<Torch::FloatType>() ||
          resultType.isa<Torch::IntType>() ||
          resultType.isa<Torch::BoolType>() ||
-         resultType.isa<Torch::NoneType>();
+         resultType.isa<Torch::NoneType>() ||
+         (resultType.isa<Torch::ListType>() && cast<Torch::ListType>(resultType)
+                                                   .getContainedType()
+                                                   .isa<Torch::IntType>());
 }
 
 static LogicalResult validateReturns(func::FuncOp func) {
19 changes: 17 additions & 2 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -52,6 +52,14 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
   }
 }
 
+static bool operatorOpHasValueSemantics(OperatorOp opOp) {
+  if (!opOp->hasAttr("has_value_semantics"))
+    return false;
+  auto hasValueSemantics =
+      opOp->getAttr("has_value_semantics").cast<BoolAttr>().getValue();
+  return hasValueSemantics;
+}
+
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
@@ -61,8 +69,13 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>())
+    if (isa<OperatorOp>(op)) {
+      if (!operatorOpHasValueSemantics(cast<OperatorOp>(op))) {
+        return rewriter.notifyMatchFailure(op, "does not have value semantics");
+      }
+    } else if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
       return rewriter.notifyMatchFailure(op, "does not have value semantics");
+    }
 
     rewriter.startRootUpdate(op);
     // Convert all operands.
@@ -254,7 +267,9 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
-      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
+      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
+          (isa<OperatorOp>(op) &&
+           operatorOpHasValueSemantics(cast<OperatorOp>(op)))) {
         auto hasValueSemantics = [](Type t) {
           // TODO: Make this an allowlist based on a closed torch dialect
           // type system.
27 changes: 23 additions & 4 deletions
@@ -58,6 +58,8 @@ createCalculateYieldCalculationOp(OpBuilder &b, Location loc,
                      "unsupported `LibraryFunctionKind`");
 }
 
+#include <iostream>
+
 LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
     Operation *op, ModuleOp library, LibraryFunctionKind libFuncKind,
     SmallVector<std::string> &libFuncNamesUsed,
@@ -66,7 +68,16 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
         libFuncArgsBuilder) {
   Location loc = op->getLoc();
   MLIRContext *context = op->getContext();
-  auto name = op->getName().stripDialect();
+
+  std::string name_;
+  if (isa<OperatorOp>(op)) {
+    auto opOp = cast<OperatorOp>(op);
+    auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+    name_ = "operator." + opName.str();
+  } else {
+    name_ = op->getName().stripDialect();
+  }
+  StringRef name = name_;
   // For value-semantic variant ops, i.e. valsem-ops (ops that are
   // mechanically consistent with existing torch conventions of in-place vs.
   // out-of-place (value-semantic) variants), remove the prefix when
@@ -76,9 +87,17 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
   std::string libFuncName =
       (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
   auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
-  if (!libFunc)
-    return success();
-  libFuncNamesUsed.push_back(libFuncName);
+  if (!libFunc) {
+    auto parentModule = op->getParentOfType<ModuleOp>();
+    if (parentModule)
+      libFunc =
+          op->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(libFuncName);
+    if (!libFunc)
+      return success();
+  } else {
+    libFuncNamesUsed.push_back(libFuncName);
+  }
+
   OpBuilder b(op);
   Operation *calculate =
       createCalculateOp(b, loc, op->getResultTypes(), libFuncKind);

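To try the demo end to end (assuming a torch-mlir development build with this patch applied):

python examples/custom_op_demo.py

Each print(module) shows the IR after the corresponding stage; the custom op should appear as torch.operator "goofy.identity" throughout, rather than being rejected as an unknown op.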