diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 6f0f09dec57c..98ee5151ed42 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -90,7 +90,8 @@ std::unique_ptr> createMaximizeValueSemanticsPass(); std::unique_ptr> createRefinePublicReturnPass(); -std::unique_ptr> createDecomposeComplexOpsPass(); +std::unique_ptr> +createDecomposeComplexOpsPass(ArrayRef legalOps); std::unique_ptr> createPreprocessShapeLibraryPass(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index bc99de7c9042..c1ce31aa6611 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -217,7 +217,9 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> { def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> { let summary = "Decompose complicated torch operations"; - let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()"; + let constructor = [{ + mlir::torch::Torch::createDecomposeComplexOpsPass(/*legalOps=*/{}) + }]; let options = [ ListOption<"legalOps", "legal-ops", "std::string", "List of operation names that should be considered legal", diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 555b96853320..ac88bc5875bc 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2646,7 +2646,9 @@ class DecomposeComplexOpsPass } }; } // namespace + std::unique_ptr> -mlir::torch::Torch::createDecomposeComplexOpsPass() { - return std::make_unique(); +mlir::torch::Torch::createDecomposeComplexOpsPass( + ArrayRef legalOps) { + return std::make_unique(legalOps); } diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 63c134e337d9..3681eef27e4d 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -140,7 +140,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline( // basic blocks. pm.addNestedPass(createCanonicalizerPass()); if (options.decompose) { - pm.addNestedPass(Torch::createDecomposeComplexOpsPass()); + pm.addNestedPass( + Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); pm.addNestedPass(createCanonicalizerPass()); } } diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 07bbade42351..e6bbd7e0b4fd 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -6,6 +6,9 @@ from typing import Sequence, Union, List from enum import Enum +import sys +from io import StringIO + import torch from torch_mlir.passmanager import PassManager @@ -116,6 +119,19 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None): return TensorPlaceholder(shape, tensor.dtype) +# The set of ops that are considered legal for each backend. +# These are currently quite load-bearing, since different backends might be +# missing patterns for decomposed forms of certain ops. +# TODO: Tighten up the definition of these "conditionally legal for backends" +# ops in the backend contract, and move these lists somewhere deeper in the +# compiler where each backend can "own" its set of legal ops. +BACKEND_LEGAL_OPS = { + OutputType.TOSA: [], + OutputType.LINALG_ON_TENSORS: [], + OutputType.MHLO: [], +} + + _example_arg = Union[TensorPlaceholder, torch.Tensor] @@ -209,14 +225,32 @@ def compile(model: torch.nn.Module, mb = ModuleBuilder() import_options = ImportOptions() import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes - mb.import_module(scripted._c, class_annotator, import_options) + try: + original_stderr = sys.stderr + sys.stderr = StringIO() + # Import the TorchScript module to MLIR + mb.import_module(scripted._c, class_annotator, import_options) + except Exception as e: + raise Exception(f""" +PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: +### Importer C++ Exception: +{e} +### Importer Diagnostics: +{sys.stderr.getvalue()} +""") from None + finally: + sys.stderr = original_stderr if output_type == OutputType.RAW: return mb.module - run_pipeline_with_repro_report(mb.module, - "torchscript-module-to-torch-backend-pipeline", - "Lowering TorchScript IR -> Torch Backend IR") + backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" + run_pipeline_with_repro_report( + mb.module, + f"torchscript-module-to-torch-backend-pipeline{option_string}", + "Lowering TorchScript IR -> Torch Backend IR", + ) if verbose: print("\n====================") diff --git a/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py b/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py index adb4e3cca3ed..d846c200db7d 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/linalg_on_tensors_backend.py @@ -3,23 +3,18 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -import sys from typing import Any -from io import StringIO -import os -import tempfile -import numpy as np import torch +import torch_mlir from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem -from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( recursively_convert_to_numpy, recursively_convert_from_numpy, - convert_torchscript_module_to_torch_backend_contract_mlir, ) @@ -34,14 +29,9 @@ def __init__(self, backend: LinalgOnTensorsBackend): self.backend = backend def compile(self, program: torch.nn.Module) -> Any: - - module = convert_torchscript_module_to_torch_backend_contract_mlir( - program) - - run_pipeline_with_repro_report( - module, - "torch-backend-to-linalg-on-tensors-backend-pipeline", - "Lower Torch Backend IR -> Linalg-on-Tensors Backend IR") + example_args = convert_annotations_to_placeholders(program.forward) + module = torch_mlir.compile( + program, example_args, output_type="linalg-on-tensors") return self.backend.compile(module) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py b/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py index f157433c7366..becfdcdfe624 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py @@ -3,22 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -import sys from typing import Any -from io import StringIO -import os -import tempfile -import numpy as np import torch +import torch_mlir from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem -from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( recursively_convert_to_numpy, recursively_convert_from_numpy, - convert_torchscript_module_to_torch_backend_contract_mlir, ) @@ -33,14 +28,9 @@ def __init__(self, backend: TosaBackend): self.backend = backend def compile(self, program: torch.nn.Module) -> Any: - - module = convert_torchscript_module_to_torch_backend_contract_mlir( - program) - - run_pipeline_with_repro_report( - module, - "torch-backend-to-tosa-backend-pipeline", - "Lower Torch Backend IR -> TOSA Backend IR") + example_args = convert_annotations_to_placeholders(program.forward) + module = torch_mlir.compile( + program, example_args, output_type="tosa") return self.backend.compile(module) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/utils.py b/python/torch_mlir_e2e_test/torchscript/configs/utils.py index 6e5da13650a3..63bbca7b0ddf 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/utils.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/utils.py @@ -50,41 +50,3 @@ def recursively_convert_from_numpy(o: Any): if isinstance(o, int): return o raise Exception(f"Unexpected Python function output: {o}") - - -def convert_torchscript_module_to_torch_backend_contract_mlir(program: torch.nn.Module): - """Perform common lowering from TorchScript to Torch MLIR - - Returns an MLIR module that satisfies the Torch backend contract. - """ - mb = ModuleBuilder() - scripted = torch.jit.script(program) - class_annotator = ClassAnnotator() - - extract_annotations(program, scripted, class_annotator) - - - # TODO: Find a way to make each of these calls own its own - # "debuggable error report" situation. - try: - original_stderr = sys.stderr - sys.stderr = StringIO() - # Import the TorchScript module to MLIR - mb.import_module(scripted._c, class_annotator) - except Exception as e: - raise Exception(f""" -PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: -Exception: -{e} -Diagnostics: -{sys.stderr.getvalue()} -""") from None - finally: - sys.stderr = original_stderr - - run_pipeline_with_repro_report( - mb.module, - "torchscript-module-to-torch-backend-pipeline", - "Lowering TorchScript Object Graph IR -> Torch Backend IR") - - return mb.module diff --git a/python/torch_mlir_e2e_test/utils.py b/python/torch_mlir_e2e_test/utils.py index e69de29bb2d1..23d7fca84cea 100644 --- a/python/torch_mlir_e2e_test/utils.py +++ b/python/torch_mlir_e2e_test/utils.py @@ -0,0 +1,22 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir import TensorPlaceholder +from torch_mlir_e2e_test.torchscript.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME + +def convert_annotations_to_placeholders(forward_method): + """Converts the annotations on a forward method into tensor placeholders. + + These placeholders are suitable for being passed to `torch_mlir.compile`. + """ + annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) + placeholders = [] + # Skip the "self" annotation. + for annotation in annotations[1:]: + if not annotation[2]: + raise ValueError( + "Can only compile inputs annotated as having value semantics.") + placeholders.append(TensorPlaceholder(annotation[0], annotation[1])) + return placeholders diff --git a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir new file mode 100644 index 000000000000..e281ac732b91 --- /dev/null +++ b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt -pass-pipeline='torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax}' -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.square +func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: torch.aten.square + %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.argmax +func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { + %int0 = torch.constant.int 0 + %true = torch.constant.bool true + // CHECK: torch.aten.argmax + %0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64> + return %0 : !torch.vtensor<[1,?],si64> +}