From 3260a1ea6e885a2b19dbcb391ea22c5cbf742047 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Thu, 22 Dec 2022 08:39:55 -0800 Subject: [PATCH] Allow passing traced `torch.nn.Module`s into `torch_mlir.compile` (#1743) This commit adds support for passing to `torch_mlir.compile` the result of running `torch.jit.trace` on a model by relaxing the condition that checks if the model is already in JIT IR to allow any `torch.jit.ScriptModule`. Fixes https://github.com/llvm/torch-mlir/issues/1739 --- python/test/compile_api/already_traced.py | 28 +++++++++++++++++++++++ python/torch_mlir/__init__.py | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 python/test/compile_api/already_traced.py diff --git a/python/test/compile_api/already_traced.py b/python/test/compile_api/already_traced.py new file mode 100644 index 000000000000..a719eb743c73 --- /dev/null +++ b/python/test/compile_api/already_traced.py @@ -0,0 +1,28 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +import torch +import torch_mlir + +class BasicModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.sin(x) + +example_arg = torch.ones(2, 3) +example_args = torch_mlir.ExampleArgs.get(example_arg) + +traced = torch.jit.trace(BasicModule(), example_arg) +print(torch_mlir.compile(traced, example_args)) +# CHECK: module +# CHECK-DAG: func.func @forward + +traced = torch.jit.trace(BasicModule(), example_arg) +try: + # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. + torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg)) +except Exception as e: + print(e) diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 9bcf4ada2ebf..3f08bb17365d 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -314,7 +314,7 @@ def compile(model: torch.nn.Module, # backend. This separation should be visible at the Python API level, and # we can implement a deliberately simplified API like `torch_mlir.compile` # on top of those building blocks. - if isinstance(model, torch.jit._script.RecursiveScriptModule): + if isinstance(model, torch.jit.ScriptModule): # If the user already converted the model to JIT IR themselves, just # do some basic error checking, but take the model as-is. for method_name in example_args._get_methods():