Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FxImporter] Add backend lowering to Fx API #3288

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Union, Optional, Sequence

import numpy as np
import torch
import torch.utils._pytree as pytree
from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram

from torch_mlir import fx
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
lower_mlir_module,
OutputType,
)
from torch_mlir.torchscript import (
BACKEND_LEGAL_OPS,
_canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
Expand All @@ -39,53 +28,6 @@ def refine_result_type(_result):
raise ValueError(f"Unhandled return type {type(_result)}")


def jit(
prog: ExportedProgram,
func_name: str,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library=None,
verbose: bool = False,
):
if extra_library is None:
extra_library = []
mlir_module = None

extra_library_file_name = _canon_extra_library(extra_library)
output_type = OutputType.get(output_type)
if backend_legal_ops is not None:
if output_type != OutputType.TORCH:
raise Exception(
"`backend_legal_ops` is only valid with the " "`torch` output type"
)
backend_legal_ops = list(sorted(set(backend_legal_ops)))
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

option_string = (
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ "}"
)

mlir_module = fx.export_and_import(prog, func_name=func_name)
assert mlir_module is not None
run_pipeline_with_repro_report(
mlir_module,
f"builtin.module(torch-simplification-pipeline)",
"Simplification pipeline for torch dialect",
)
run_pipeline_with_repro_report(
mlir_module,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
)

return lower_mlir_module(verbose, output_type, mlir_module)


class FxImporterTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with Fx Importer"""

Expand All @@ -100,11 +42,11 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module:
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
prog = torch.export.export(artifact, tuple(item.inputs))
module = jit(
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
module = fx.export_and_import(
prog,
func_name=artifact.__class__.__name__,
output_type=self._output_type,
func_name=artifact.__class__.__name__,
)
module = self._backend.compile(module)
backend_module = self._backend.load(module)
Expand Down
50 changes: 48 additions & 2 deletions python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,58 @@
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.compiler_utils import (
OutputType,
run_pipeline_with_repro_report,
lower_mlir_module,
)


def _module_lowering(
verbose,
output_type,
torch_mod,
backend_legal_ops=None,
extra_library_file_name=None,
):

if output_type == OutputType.TORCH:
if verbose:
print(torch_mod)
return torch_mod
# TODO: pass backend_legal_ops/extra_library_file_name by caller
if backend_legal_ops is None:
backend_legal_ops = []
if extra_library_file_name is None:
extra_library_file_name = ""
option_string = (
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ "}"
)
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=verbose,
)
penguin-wwy marked this conversation as resolved.
Show resolved Hide resolved
return lower_mlir_module(verbose, output_type, torch_mod)


def export_and_import(
f: Union[nn.Module, ExportedProgram],
*args,
output_type: Union[str, OutputType] = OutputType.TORCH,
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
func_name: str = "main",
enable_graph_printing: bool = False,
enable_ir_printing: bool = False,
**kwargs,
):
context = ir.Context()
Expand All @@ -52,15 +92,19 @@ def export_and_import(
else:
fx_importer.import_frozen_program(prog, func_name=func_name)

return fx_importer.module
return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module
)


def stateless_fx_import(
gm: torch.fx.GraphModule,
output_type: Union[str, OutputType] = OutputType.TORCH,
fx_importer: Optional[FxImporter] = None,
hooks: Optional[FxImporterHooks] = None,
model_name: str = "main",
enable_graph_printing: bool = False,
enable_ir_printing: bool = False,
):
if enable_graph_printing:
gm.print_readable()
Expand All @@ -69,4 +113,6 @@ def stateless_fx_import(
if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks)
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
return fx_importer.module
return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module
)
Loading