[FxImporter] Add backend lowering to Fx API (llvm#3288)
penguin-wwy authored May 7, 2024
1 parent 6f911ba commit c3bd850
Showing 2 changed files with 51 additions and 63 deletions.
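In short, this commit moves the torch-to-backend lowering out of the e2e test config's local `jit()` helper and into the Fx API itself: `fx.export_and_import` (and `stateless_fx_import`) now accept an `output_type` and run the lowering pipeline before returning the module. A minimal sketch of the resulting call, assuming a toy `AddOne` module and tensor shape that are illustrative rather than part of the patch:

import torch
from torch_mlir import fx

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Export with torch.export, then import and lower in a single call.
prog = torch.export.export(AddOne(), (torch.ones(3),))
module = fx.export_and_import(
    prog,
    output_type="linalg-on-tensors",  # previously the Fx API stopped at the torch dialect
    func_name="AddOne",
)
print(module)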
@@ -3,24 +3,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Union, Optional, Sequence

import numpy as np
import torch
import torch.utils._pytree as pytree
from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram

from torch_mlir import fx
from torch_mlir.compiler_utils import (
    run_pipeline_with_repro_report,
    lower_mlir_module,
    OutputType,
)
from torch_mlir.torchscript import (
    BACKEND_LEGAL_OPS,
    _canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
    recursively_convert_to_numpy,
    recursively_convert_from_numpy,
@@ -39,53 +28,6 @@ def refine_result_type(_result):
        raise ValueError(f"Unhandled return type {type(_result)}")


def jit(
    prog: ExportedProgram,
    func_name: str,
    output_type: Union[str, "OutputType"] = OutputType.TORCH,
    backend_legal_ops: Optional[Sequence[str]] = None,
    extra_library=None,
    verbose: bool = False,
):
    if extra_library is None:
        extra_library = []
    mlir_module = None

    extra_library_file_name = _canon_extra_library(extra_library)
    output_type = OutputType.get(output_type)
    if backend_legal_ops is not None:
        if output_type != OutputType.TORCH:
            raise Exception(
                "`backend_legal_ops` is only valid with the " "`torch` output type"
            )
        backend_legal_ops = list(sorted(set(backend_legal_ops)))
    else:
        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

    option_string = (
        "{backend-legal-ops="
        + ",".join(backend_legal_ops)
        + " extra-library="
        + extra_library_file_name
        + "}"
    )

    mlir_module = fx.export_and_import(prog, func_name=func_name)
    assert mlir_module is not None
    run_pipeline_with_repro_report(
        mlir_module,
        f"builtin.module(torch-simplification-pipeline)",
        "Simplification pipeline for torch dialect",
    )
    run_pipeline_with_repro_report(
        mlir_module,
        f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
        "Lowering TorchFX IR -> Torch Backend IR",
    )

    return lower_mlir_module(verbose, output_type, mlir_module)


class FxImporterTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with Fx Importer"""

@@ -100,11 +42,11 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module:
    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
        result: Trace = []
        for item in trace:
            prog = torch.export.export(artifact, tuple(item.inputs))
            module = jit(
            prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
            module = fx.export_and_import(
                prog,
                func_name=artifact.__class__.__name__,
                output_type=self._output_type,
                func_name=artifact.__class__.__name__,
            )
            module = self._backend.compile(module)
            backend_module = self._backend.load(module)
50 changes: 48 additions & 2 deletions python/torch_mlir/fx.py
@@ -16,18 +16,58 @@
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.compiler_utils import (
    OutputType,
    run_pipeline_with_repro_report,
    lower_mlir_module,
)


def _module_lowering(
    verbose,
    output_type,
    torch_mod,
    backend_legal_ops=None,
    extra_library_file_name=None,
):

    if output_type == OutputType.TORCH:
        if verbose:
            print(torch_mod)
        return torch_mod
    # TODO: pass backend_legal_ops/extra_library_file_name by caller
    if backend_legal_ops is None:
        backend_legal_ops = []
    if extra_library_file_name is None:
        extra_library_file_name = ""
    option_string = (
        "{backend-legal-ops="
        + ",".join(backend_legal_ops)
        + " extra-library="
        + extra_library_file_name
        + "}"
    )
    run_pipeline_with_repro_report(
        torch_mod,
        f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
        "Lowering TorchFX IR -> Torch Backend IR",
        enable_ir_printing=verbose,
    )
    return lower_mlir_module(verbose, output_type, torch_mod)


def export_and_import(
    f: Union[nn.Module, ExportedProgram],
    *args,
    output_type: Union[str, OutputType] = OutputType.TORCH,
    fx_importer: Optional[FxImporter] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    experimental_support_mutation: bool = False,
    hooks: Optional[FxImporterHooks] = None,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    func_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
    **kwargs,
):
    context = ir.Context()
@@ -52,15 +92,19 @@ def export_and_import(
    else:
        fx_importer.import_frozen_program(prog, func_name=func_name)

    return fx_importer.module
    return _module_lowering(
        enable_ir_printing, OutputType.get(output_type), fx_importer.module
    )


def stateless_fx_import(
    gm: torch.fx.GraphModule,
    output_type: Union[str, OutputType] = OutputType.TORCH,
    fx_importer: Optional[FxImporter] = None,
    hooks: Optional[FxImporterHooks] = None,
    model_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
):
    if enable_graph_printing:
        gm.print_readable()
@@ -69,4 +113,6 @@ def stateless_fx_import(
    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)
    fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
    return fx_importer.module
    return _module_lowering(
        enable_ir_printing, OutputType.get(output_type), fx_importer.module
    )
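Usage note: `stateless_fx_import` gains the same `output_type` and `enable_ir_printing` parameters and also routes through `_module_lowering`. A hedged sketch of how a dynamo-style backend might call it; `my_backend` and the eager fallback are assumptions for illustration, not part of this patch:

import torch
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Import the FX graph and lower it to Torch backend IR in one step.
    mlir_module = fx.stateless_fx_import(
        gm,
        output_type=OutputType.TORCH,
        model_name="main",
        enable_ir_printing=False,
    )
    print(mlir_module)
    return gm.forward  # fall back to eager execution in this sketch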
