RFC: Support pushing custom ops through backend-contract using torch.operator
#1959
@@ -0,0 +1,95 @@
import torch
import torch.utils.cpp_extension
import torch_mlir
from torch_mlir import run_pipeline_with_repro_report
from torch_mlir.ir import BoolAttr, Context, Module, InsertionPoint, Location
from torch_mlir_e2e_test.annotations import export, annotate_args
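# Plain Python function that will serve as the implementation of the custom op.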
def identity(_5: torch.Tensor):
    return _5
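# Register the op with PyTorch: create the "goofy" namespace, declare the
# schema, and bind the Python implementation. It is then callable as
# torch.ops.goofy.identity.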
goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)
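# Example module whose forward() calls the custom op. The annotate_args
# decorator supplies the argument shapes/dtypes used by torch-mlir.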
class CustomOpExampleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        b = 2 * a
        return torch.ops.goofy.identity(b)
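# Import the module into MLIR without lowering it (output_type="raw").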
mod = CustomOpExampleModule()
mod.eval()

module = torch_mlir.compile(mod, torch.ones(3, 4), output_type="raw")
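# First stage of the usual TorchScript lowering: globalize the object graph,
# inline, and adjust calling conventions.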
pipeline = (
    "symbol-dce,"
    "torch-prepare-for-globalize-object-graph,"
    "torch-globalize-object-graph,"
    "symbol-dce,"
    "inline{default-pipeline= max-iterations=4 },"
    "torch-adjust-calling-conventions"
)

run_pipeline_with_repro_report(
    module, pipeline=f"builtin.module({pipeline})", description=""
)
print(module)
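# Reach into the IR (by position) for the torch.operator call produced by
# torch.ops.goofy.identity and mark it as having value semantics. The
# hard-coded operation indices assume the IR structure produced by the
# pipeline above.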
forward = module.operation.regions[0].blocks[0].operations[1]
goofy_op = forward.operation.regions[0].blocks[0].operations[4]
goofy_op.attributes["has_value_semantics"] = BoolAttr.get(True, context=module.context)

print(module)
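# Shape and dtype functions for the custom op. The "operator.<name>" suffix in
# the symbol names matches the lookup scheme added to
# wrapWithCalculateOpIfLibraryFunctionAvailable in the C++ change below.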
abstract_interp_src = """\
func.func @__torch_mlir_shape_fn.operator.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
  return %arg0 : !torch.list<int>
}
func.func @__torch_mlir_dtype_fn.operator.goofy.identity(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  return %arg1 : !torch.int
}
"""
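# Parse the shape/dtype functions and splice them into the module being
# compiled, so the library-function lookup can find them in the op's parent
# module.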
with Location.unknown(module.context) as loc:
    abstract_interp_module = Module.parse(abstract_interp_src)
    with InsertionPoint.at_block_begin(module.body) as ip:
        shape_fn = abstract_interp_module.body.operations[0]
        dtype_fn = abstract_interp_module.body.operations[1]
        InsertionPoint.insert(ip, shape_fn.detach_from_parent())
        InsertionPoint.insert(ip, dtype_fn.detach_from_parent())

print(module)
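# Convert ops to their value-semantic forms; the has_value_semantics attribute
# set above is intended to let the custom torch.operator participate in this
# step.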
run_pipeline_with_repro_report(
    module,
    pipeline="builtin.module(func.func(torch-reduce-op-variants,torch-maximize-value-semantics))",
    description="",
)

print(module)
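# Lower to the backend contract, declaring torch.operator backend-legal so the
# custom op is allowed to survive into the backend IR.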
run_pipeline_with_repro_report(
    module,
    pipeline="builtin.module(torch-lower-to-backend-contract{backend-legal-ops=torch.operator decompose=true max-iterations=10})",
    description="",
)
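# Detach the helper shape/dtype functions again so they do not remain in the
# final module.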
shape_fn.detach_from_parent()
dtype_fn.detach_from_parent()

print(module)
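As a sanity check, something like the following could be appended to the script. This is illustrative only (not part of the patch) and assumes the op keeps its "goofy.identity" name in the printed IR:

# Illustrative check: the custom op should still be present after the
# backend-contract lowering above.
final_ir = str(module)
assert "torch.operator" in final_ir and "goofy.identity" in final_ir

The C++ hunks that follow are the compiler-side changes this script relies on.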
@@ -191,7 +191,10 @@ static bool isValidNonContainerResultType(Type resultType) {
          resultType.isa<Torch::FloatType>() ||
          resultType.isa<Torch::IntType>() ||
          resultType.isa<Torch::BoolType>() ||
-         resultType.isa<Torch::NoneType>();
+         resultType.isa<Torch::NoneType>() ||
+         (resultType.isa<Torch::ListType>() && cast<Torch::ListType>(resultType)
+                                                   .getContainedType()
+                                                   .isa<Torch::IntType>());
 }

 static LogicalResult validateReturns(func::FuncOp func) {

Review comment: We shouldn't be doing this. This function was created with the goal of preventing something like a […]

Review comment: That's fine, but if you want to enable user-provided shape and dtype functions in the same parent module then there needs to be special casing for them. The alternative is to provide some mechanism for passing handles to a […]

Review comment: User-provided shape and dtype functions will be handled exactly the same way that current shape and dtype functions are handled. The plan is to load them from a […]

Review comment: shape functions return […]
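For context on the review thread above: torch-mlir shape functions compute shapes as lists of integers, which is where the list-of-int result type comes from. Written as ordinary Python (the form the upstream abstract interpretation library is generated from), the shape function for the custom op would look roughly like this; the function name is purely illustrative, and the example script above writes the MLIR directly instead:

from typing import List

def goofy_identity_shape(t: List[int]) -> List[int]:
    # identity: the output shape equals the input shape
    return t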
@@ -66,7 +66,16 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
     libFuncArgsBuilder) {
   Location loc = op->getLoc();
   MLIRContext *context = op->getContext();
-  auto name = op->getName().stripDialect();
+
+  std::string name_;
+  if (isa<OperatorOp>(op)) {
+    auto opOp = cast<OperatorOp>(op);
+    auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+    name_ = "operator." + opName.str();
+  } else {
+    name_ = op->getName().stripDialect();
+  }
+  StringRef name = name_;
   // For value-semantic variant ops, i.e. valsem-ops (ops that are
   // mechanically consistent with existing torch conventions of in-place vs.
   // out-of-place (value-semantic) variants), remove the prefix when

Review comment (on the name_ = "operator." + opName.str() line): shape/dtype functions for […]
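Concretely, for the example above this prefixing maps the torch.operator "goofy.identity" op to the symbol names used in abstract_interp_src. A small illustrative sketch of the mangling (not code from the PR):

def library_symbol(kind, op_name):
    # Mirrors the C++ above: torch.operator ops are looked up under an extra
    # "operator." prefix after the usual __torch_mlir_<kind>_fn. prefix.
    return f"__torch_mlir_{kind}_fn.operator.{op_name}"

assert library_symbol("shape", "goofy.identity") == "__torch_mlir_shape_fn.operator.goofy.identity"
assert library_symbol("dtype", "goofy.identity") == "__torch_mlir_dtype_fn.operator.goofy.identity"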
@@ -76,9 +85,17 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
   std::string libFuncName =
       (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
   auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
-  if (!libFunc)
-    return success();
-  libFuncNamesUsed.push_back(libFuncName);
+  if (!libFunc) {
+    auto parentModule = op->getParentOfType<ModuleOp>();
+    if (parentModule)
+      libFunc =
+          op->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(libFuncName);
+    if (!libFunc)
+      return success();
+  } else {
+    libFuncNamesUsed.push_back(libFuncName);
+  }

   OpBuilder b(op);
   Operation *calculate =
       createCalculateOp(b, loc, op->getResultTypes(), libFuncKind);
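The effect of this hunk is a fallback in the lookup order: the shape/dtype library is consulted first, then the op's parent module, and only then does the helper give up. This is what lets the Python example above splice its shape/dtype functions directly into the compiled module. A toy sketch of that order, with dictionaries standing in for symbol tables (purely illustrative):

def lookup_lib_func(name, shape_library, parent_module):
    fn = shape_library.get(name)          # existing behavior: check the library
    if fn is None and parent_module is not None:
        fn = parent_module.get(name)      # new behavior: check the parent module
    return fn                             # None means: leave the op unwrapped

library = {}  # the stock library has no entry for the custom op
parent = {"__torch_mlir_shape_fn.operator.goofy.identity": "<func.func>"}
assert lookup_lib_func("__torch_mlir_shape_fn.operator.goofy.identity", library, parent) is not None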
Review comment: Same thing as the "classical" torch custom op registration; this torch.jit.traces to […]