
How to support self-define torch op? #1212

Closed
KangHe000 opened this issue Aug 11, 2022 · 3 comments

Comments

@KangHe000

Hi, I have some special self-defined torch ops and want to lower them to the torch dialect.
Does torch-mlir provide an export mechanism to support self-defined torch ops?
When exporting to ONNX, I can use torch.autograd.Function to meet this kind of need, but torch-mlir doesn't seem to support that.
I got the error below:

python builtin <built-in method apply of FunctionMeta object at 0x56353a8494f0> is currently not supported in Torchscript:
  File "my_symbolic.py", line 20
    def forward(self, x):
        return TestKernel.apply(x)
               ~~~~~~~~~~~~~~~~ <--- HERE

When I try to run my compile code:

import torch
import torch.nn as nn
import torch.autograd as AF
import torch_mlir

class TestKernel(AF.Function):
    @staticmethod
    def symbolic(g, inputs):
        # Consulted by torch.onnx.export to emit a custom node
        return g.op("torch::symbolic", inputs)

    @staticmethod
    def forward(ctx, inputs):
        out = torch.exp(-0.5 * ((inputs - 1) ** 2) / 4)
        return out

class TestKernelTest(nn.Module):
    def forward(self, x):
        return TestKernel.apply(x)

if __name__ == "__main__":
    model = TestKernelTest()
    x = torch.randn(2, 3)
    res = model(x)
    # ONNX export works, thanks to the symbolic() method above
    torch.onnx.export(model, x, "test_symbolic.onnx", verbose=True)
    traced = torch.jit.trace(model, x)
    # Fails: TorchScript cannot compile TestKernel.apply
    module = torch_mlir.compile(model, x, output_type=torch_mlir.OutputType.TORCH)
    #module = torch_mlir.compile(model, x, output_type=torch_mlir.OutputType.TORCH, use_tracing=True)

@ramiro050
Collaborator

Hi @KangHe000,

The error you're getting is from PyTorch, not torch-mlir: the PyTorch machinery that compiles a module to TorchScript does not support torch.autograd.Function subclasses.

In torch-mlir there is currently an open PR that adds custom op support through a different method. The approach is to register a custom op in PyTorch, and then add support for it in torch-mlir.

@KangHe000
Author

@ramiro050 Sorry for replying so late; I will try registering a custom op in PyTorch. Thanks!

qedawkins pushed a commit to nod-ai/torch-mlir that referenced this issue Oct 3, 2022
@silvasean
Contributor

Closing this as it appears we found a solution.
FYI we have a custom ops RFC as well: #1462
