diff --git a/examples/custom_op_demo.py b/examples/custom_op_demo.py
new file mode 100644
index 000000000000..d9c3c8289505
--- /dev/null
+++ b/examples/custom_op_demo.py
@@ -0,0 +1,95 @@
+import torch
+import torch.utils.cpp_extension
+import torch_mlir
+from torch_mlir import run_pipeline_with_repro_report
+from torch_mlir.ir import BoolAttr, Context, Module, InsertionPoint, Location
+from torch_mlir_e2e_test.annotations import export, annotate_args
+
+
+def identity(_5: torch.Tensor):
+    return _5
+
+
+goofy_lib = torch.library.Library("goofy", "DEF")
+goofy_lib.define("identity(Tensor t) -> Tensor")
+goofy_lib.impl("identity", identity)
+
+
+class CustomOpExampleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        b = 2 * a
+        return torch.ops.goofy.identity(b)
+
+
+mod = CustomOpExampleModule()
+mod.eval()
+
+module = torch_mlir.compile(mod, torch.ones(3, 4), output_type="raw")
+
+pipeline = (
+    "symbol-dce,"
+    "torch-prepare-for-globalize-object-graph,"
+    "torch-globalize-object-graph,"
+    "symbol-dce,"
+    "inline{default-pipeline= max-iterations=4 },"
+    "torch-adjust-calling-conventions"
+)
+
+run_pipeline_with_repro_report(
+    module, pipeline=f"builtin.module({pipeline})", description=""
+)
+print(module)
+
+forward = module.operation.regions[0].blocks[0].operations[1]
+goofy_op = forward.operation.regions[0].blocks[0].operations[4]
+goofy_op.attributes["has_value_semantics"] = BoolAttr.get(True, context=module.context)
+
+print(module)
+
+abstract_interp_src = """\
+func.func @__torch_mlir_shape_fn.operator.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
+  return %arg0 : !torch.list<int>
+}
+func.func @__torch_mlir_dtype_fn.operator.goofy.identity(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+  return %arg1 : !torch.int
+}
+"""
+
+with Location.unknown(module.context) as loc:
+    abstract_interp_module = Module.parse(abstract_interp_src)
+    with InsertionPoint.at_block_begin(module.body) as ip:
+        shape_fn = abstract_interp_module.body.operations[0]
+        dtype_fn = abstract_interp_module.body.operations[1]
+        InsertionPoint.insert(ip, shape_fn.detach_from_parent())
+        InsertionPoint.insert(ip, dtype_fn.detach_from_parent())
+
+print(module)
+
+run_pipeline_with_repro_report(
+    module,
+    pipeline="builtin.module(func.func(torch-reduce-op-variants,torch-maximize-value-semantics))",
+    description="",
+)
+
+print(module)
+
+run_pipeline_with_repro_report(
+    module,
+    pipeline="builtin.module(torch-lower-to-backend-contract{backend-legal-ops=torch.operator decompose=true max-iterations=10})",
+    description="",
+)
+
+shape_fn.detach_from_parent()
+dtype_fn.detach_from_parent()
+
+print(module)
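(Not part of the patch.) The demo's eager registration can be sanity-checked without involving torch-mlir at all; a minimal sketch, assuming only the `torch.library` calls shown in the demo (the `lib` and `x` names are illustrative):

```python
import torch

# Register the same opaque op the demo defines, via the torch.library API.
lib = torch.library.Library("goofy", "DEF")
lib.define("identity(Tensor t) -> Tensor")
lib.impl("identity", lambda t: t)

# The op is now reachable under torch.ops and round-trips its input unchanged.
x = torch.ones(3, 4)
assert torch.equal(torch.ops.goofy.identity(x), x)
```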
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 889a29908de0..4d1176e4a469 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -829,7 +829,8 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
 }
 
 def Torch_OperatorOp : Torch_Op<"operator", [
-  AllowsTypeRefinement
+  AllowsTypeRefinement,
+  HasValueSemantics
 ]> {
   let summary = "Opaque torch operator";
   let description = [{
diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
index 7cc699c04762..d4ac36a1c4e8 100644
--- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
+++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -191,7 +191,10 @@ static bool isValidNonContainerResultType(Type resultType) {
          resultType.isa() ||
          resultType.isa() ||
          resultType.isa() ||
-         resultType.isa();
+         resultType.isa() ||
+         (resultType.isa() && cast(resultType)
+                                  .getContainedType()
+                                  .isa());
 }
 
 static LogicalResult validateReturns(func::FuncOp func) {
diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp
index 04bbe0220e71..87085f9b8ffa 100644
--- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp
+++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp
@@ -66,7 +66,16 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
         libFuncArgsBuilder) {
   Location loc = op->getLoc();
   MLIRContext *context = op->getContext();
-  auto name = op->getName().stripDialect();
+
+  std::string name_;
+  if (isa<OperatorOp>(op)) {
+    auto opOp = cast<OperatorOp>(op);
+    auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+    name_ = "operator." + opName.str();
+  } else {
+    name_ = op->getName().stripDialect();
+  }
+  StringRef name = name_;
   // For value-semantic variant ops, i.e. valsem-ops (ops that are
   // mechanically consistent with existing torch conventions of in-place vs.
   // out-of-place (value-semantic) variants), remove the prefix when
@@ -76,9 +85,17 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
   std::string libFuncName =
       (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
   auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
-  if (!libFunc)
-    return success();
-  libFuncNamesUsed.push_back(libFuncName);
+  if (!libFunc) {
+    auto parentModule = op->getParentOfType<ModuleOp>();
+    if (parentModule)
+      libFunc =
+          op->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(libFuncName);
+    if (!libFunc)
+      return success();
+  } else {
+    libFuncNamesUsed.push_back(libFuncName);
+  }
+
   OpBuilder b(op);
   Operation *calculate =
       createCalculateOp(b, loc, op->getResultTypes(), libFuncKind);
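(Also not part of the patch.) The ReifyAbstractInterpCalculationsUtils change keys a `torch.operator` op by `"operator."` plus its `name` attribute when looking up shape and dtype library functions, which is why the demo names its functions `__torch_mlir_shape_fn.operator.goofy.identity` and `__torch_mlir_dtype_fn.operator.goofy.identity`. A Python mirror of that name construction, with the hypothetical helper `abstract_interp_key` standing in for the C++ logic:

```python
def abstract_interp_key(prefix: str, op_name: str, is_torch_operator: bool) -> str:
    """Mirror of the C++ lookup-name construction: torch.operator ops are
    keyed by 'operator.' plus their name attribute; ordinary torch dialect
    ops by their dialect-stripped op name."""
    name = f"operator.{op_name}" if is_torch_operator else op_name
    return prefix + name

# The demo's two abstract interpretation functions resolve to these symbols:
assert (abstract_interp_key("__torch_mlir_shape_fn.", "goofy.identity", True)
        == "__torch_mlir_shape_fn.operator.goofy.identity")
assert (abstract_interp_key("__torch_mlir_dtype_fn.", "goofy.identity", True)
        == "__torch_mlir_dtype_fn.operator.goofy.identity")
```

The fallback added in the second hunk means such functions may live either in the bundled abstract-interpretation library or directly in the user's module, which is how the demo injects them via `InsertionPoint` and detaches them again before the final print.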