Skip to content

Commit

Permalink
Allow passing traced torch.nn.Modules into torch_mlir.compile (#1743
Browse files Browse the repository at this point in the history
)

This commit adds support for passing to `torch_mlir.compile` the
result of running `torch.jit.trace` on a model by relaxing the
condition that checks if the model is already in JIT IR to allow any
`torch.jit.ScriptModule`.

Fixes #1739
  • Loading branch information
ramiro050 authored Dec 22, 2022
1 parent 52669cb commit 3260a1e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
28 changes: 28 additions & 0 deletions python/test/compile_api/already_traced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch
import torch_mlir

class BasicModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.sin(x)

example_arg = torch.ones(2, 3)
example_args = torch_mlir.ExampleArgs.get(example_arg)

traced = torch.jit.trace(BasicModule(), example_arg)
print(torch_mlir.compile(traced, example_args))
# CHECK: module
# CHECK-DAG: func.func @forward

traced = torch.jit.trace(BasicModule(), example_arg)
try:
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
except Exception as e:
print(e)
2 changes: 1 addition & 1 deletion python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def compile(model: torch.nn.Module,
# backend. This separation should be visible at the Python API level, and
# we can implement a deliberately simplified API like `torch_mlir.compile`
# on top of those building blocks.
if isinstance(model, torch.jit._script.RecursiveScriptModule):
if isinstance(model, torch.jit.ScriptModule):
# If the user already converted the model to JIT IR themselves, just
# do some basic error checking, but take the model as-is.
for method_name in example_args._get_methods():
Expand Down

0 comments on commit 3260a1e

Please sign in to comment.