From 7c014352cd21362ab62e04d2abe16ecbc0e91517 Mon Sep 17 00:00:00 2001 From: penguin-wwy <940375606@qq.com> Date: Mon, 6 May 2024 14:21:45 +0800 Subject: [PATCH] [FxImporter] Add backend lowering to Fx API --- .../configs/fx_importer_backend.py | 64 +------------------ python/torch_mlir/fx.py | 50 ++++++++++++++- 2 files changed, 51 insertions(+), 63 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 204ddf61674b..2a63c06bdc37 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -3,8 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import Union, Optional, Sequence - import numpy as np import torch import torch.utils._pytree as pytree @@ -12,15 +10,6 @@ from torch.export import ExportedProgram from torch_mlir import fx -from torch_mlir.compiler_utils import ( - run_pipeline_with_repro_report, - lower_mlir_module, - OutputType, -) -from torch_mlir.torchscript import ( - BACKEND_LEGAL_OPS, - _canon_extra_library, -) from torch_mlir_e2e_test.configs.utils import ( recursively_convert_to_numpy, recursively_convert_from_numpy, @@ -39,53 +28,6 @@ def refine_result_type(_result): raise ValueError(f"Unhandled return type {type(_result)}") -def jit( - prog: ExportedProgram, - func_name: str, - output_type: Union[str, "OutputType"] = OutputType.TORCH, - backend_legal_ops: Optional[Sequence[str]] = None, - extra_library=None, - verbose: bool = False, -): - if extra_library is None: - extra_library = [] - mlir_module = None - - extra_library_file_name = _canon_extra_library(extra_library) - output_type = OutputType.get(output_type) - if backend_legal_ops is not None: - if output_type != OutputType.TORCH: - raise Exception( - "`backend_legal_ops` is only valid with the " "`torch` output type" - ) - backend_legal_ops = list(sorted(set(backend_legal_ops))) - else: - backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) - - option_string = ( - "{backend-legal-ops=" - + ",".join(backend_legal_ops) - + " extra-library=" - + extra_library_file_name - + "}" - ) - - mlir_module = fx.export_and_import(prog, func_name=func_name) - assert mlir_module is not None - run_pipeline_with_repro_report( - mlir_module, - f"builtin.module(torch-simplification-pipeline)", - "Simplification pipeline for torch dialect", - ) - run_pipeline_with_repro_report( - mlir_module, - f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", - "Lowering TorchFX IR -> Torch Backend IR", - ) - - return lower_mlir_module(verbose, output_type, mlir_module) - - class FxImporterTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module with Fx Importer""" @@ -100,11 +42,11 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module: def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] for item in trace: - prog = torch.export.export(artifact, tuple(item.inputs)) - module = jit( + prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs)) + module = fx.export_and_import( prog, - func_name=artifact.__class__.__name__, output_type=self._output_type, + func_name=artifact.__class__.__name__, ) module = self._backend.compile(module) backend_module = self._backend.load(module) diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 651ccae673a6..8d5c5cb1125c 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -16,11 +16,50 @@ from torch_mlir import ir from torch_mlir.dialects import torch as torch_d from torch_mlir.extras.fx_decomp_util import get_decomposition_table +from torch_mlir.compiler_utils import ( + OutputType, + run_pipeline_with_repro_report, + lower_mlir_module, +) + + +def _module_lowering( + verbose, + output_type, + torch_mod, + backend_legal_ops=None, + extra_library_file_name=None, +): + + if output_type == OutputType.TORCH: + if verbose: + print(torch_mod) + return torch_mod + # TODO: pass backend_legal_ops/extra_library_file_name by caller + if backend_legal_ops is None: + backend_legal_ops = [] + if extra_library_file_name is None: + extra_library_file_name = "" + option_string = ( + "{backend-legal-ops=" + + ",".join(backend_legal_ops) + + " extra-library=" + + extra_library_file_name + + "}" + ) + run_pipeline_with_repro_report( + torch_mod, + f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", + "Lowering TorchFX IR -> Torch Backend IR", + enable_ir_printing=verbose, + ) + return lower_mlir_module(verbose, output_type, torch_mod) def export_and_import( f: Union[nn.Module, ExportedProgram], *args, + output_type: Union[str, OutputType] = OutputType.TORCH, fx_importer: Optional[FxImporter] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, @@ -28,6 +67,7 @@ def export_and_import( decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, func_name: str = "main", enable_graph_printing: bool = False, + enable_ir_printing: bool = False, **kwargs, ): context = ir.Context() @@ -52,15 +92,19 @@ def export_and_import( else: fx_importer.import_frozen_program(prog, func_name=func_name) - return fx_importer.module + return _module_lowering( + enable_ir_printing, OutputType.get(output_type), fx_importer.module + ) def stateless_fx_import( gm: torch.fx.GraphModule, + output_type: Union[str, OutputType] = OutputType.TORCH, fx_importer: Optional[FxImporter] = None, hooks: Optional[FxImporterHooks] = None, model_name: str = "main", enable_graph_printing: bool = False, + enable_ir_printing: bool = False, ): if enable_graph_printing: gm.print_readable() @@ -69,4 +113,6 @@ def stateless_fx_import( if fx_importer is None: fx_importer = FxImporter(context=context, hooks=hooks) fx_importer.import_stateless_graph(gm.graph, func_name=model_name) - return fx_importer.module + return _module_lowering( + enable_ir_printing, OutputType.get(output_type), fx_importer.module + )