-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
See RFC #999 Co-authored-by: Bairen Yi [email protected] Co-authored-by: Jiawei Wu [email protected] Co-authored-by: Tianyou Guo [email protected] Co-authored-by: Xu Yan [email protected] Co-authored-by: Ziheng Jiang [email protected]
- Loading branch information
Tanyo Kwok
committed
Aug 14, 2022
1 parent
41aa562
commit 88a46e7
Showing
9 changed files
with
298 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
import abc | ||
from typing import TypeVar | ||
|
||
import torch | ||
|
||
from torch_mlir.ir import Module | ||
|
||
# A type shared between the result of `MhloBackend.compile` and the | ||
# input to `MhloBackend.load`. Each backend will likely have a | ||
# different definition of this type. | ||
CompiledArtifact = TypeVar('CompiledArtifact') | ||
|
||
# A wrapper around a backend-specific loaded program representation | ||
# that uniformly translates the `x.method(...)` interface expected of | ||
# Torch modules into appropriate lower-level operations. | ||
Invoker = TypeVar('Invoker') | ||
|
||
|
||
class MhloBackend(abc.ABC): | ||
"""The interface to an MHLO backend. | ||
Backends are recommended to raise meaningful exceptions in case of error, | ||
ideally with easy reproduction instructions. | ||
""" | ||
@abc.abstractmethod | ||
def compile(self, module: Module) -> CompiledArtifact: | ||
"""Compile the provided MLIR module into a compiled artifact. | ||
The module adheres to the MHLO backend contract | ||
(see the VerifyMhloBackendContract pass). | ||
The compiled artifact can be any type, but must be correctly | ||
interpreted by the `load` method. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def load(self, artifact: CompiledArtifact) -> Invoker: | ||
"""Load the compiled artifact into a uniformly invokable form. | ||
The compiled artifact is the result of a previous call to `compile`. | ||
See the description of `Invoker` for the requirements on the returned | ||
type. | ||
""" |
45 changes: 45 additions & 0 deletions
45
python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
from torch_mlir.ir import * | ||
from torch_mlir.passmanager import * | ||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report | ||
|
||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend | ||
|
||
from .abc import MhloBackend | ||
|
||
__all__ = [ | ||
"LinalgOnTensorsMhloBackend", | ||
] | ||
|
||
class LinalgOnTensorsMhloBackend(MhloBackend): | ||
"""Main entry-point for the linalg-on-tensors based MHLO backend. | ||
This currently uses the linalg-on-tensors RefBackend for actual execution. | ||
""" | ||
def __init__(self): | ||
super().__init__() | ||
self.refbackend = RefBackendLinalgOnTensorsBackend() | ||
|
||
def compile(self, imported_module: Module): | ||
"""Compiles an imported module that satisfied the MHLO backend contract. | ||
Args: | ||
imported_module: The MLIR module consisting of funcs in the MHLO | ||
dialect. | ||
Returns: | ||
An opaque, backend specific compiled artifact object that can be | ||
passed to `load`. | ||
""" | ||
run_pipeline_with_repro_report( | ||
imported_module, | ||
"func.func(hlo-legalize-to-linalg)", | ||
"Lowering MLIR-HLO to Linalg-on-Tensors") | ||
return self.refbackend.compile(imported_module) | ||
|
||
def load(self, module): | ||
"""Loads a compiled artifact into the runtime.""" | ||
return self.refbackend.load(module) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
# Also available under a BSD-style license. See LICENSE. | ||
|
||
import sys | ||
from typing import Any | ||
from io import StringIO | ||
import os | ||
import tempfile | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend | ||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem | ||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report | ||
from .utils import ( | ||
recursively_convert_to_numpy, | ||
recursively_convert_from_numpy, | ||
convert_torchscript_module_to_torch_backend_contract_mlir, | ||
) | ||
|
||
|
||
class MhloBackendTestConfig(TestConfig): | ||
"""Base class for TestConfig's that are implemented with linalg-on-tensors. | ||
This class handles all the common lowering that torch-mlir does before | ||
reaching the linalg-on-tensors abstraction level. | ||
""" | ||
def __init__(self, backend: MhloBackend): | ||
super().__init__() | ||
self.backend = backend | ||
|
||
def compile(self, program: torch.nn.Module) -> Any: | ||
|
||
module = convert_torchscript_module_to_torch_backend_contract_mlir( | ||
program) | ||
|
||
run_pipeline_with_repro_report( | ||
module, | ||
"torch-backend-to-mhlo-backend-pipeline", | ||
"Lower Torch Backend IR -> MHLO Backend IR") | ||
|
||
return self.backend.compile(module) | ||
|
||
|
||
|
||
def run(self, artifact: Any, trace: Trace) -> Trace: | ||
backend_module = self.backend.load(artifact) | ||
result: Trace = [] | ||
for item in trace: | ||
numpy_inputs = recursively_convert_to_numpy(item.inputs) | ||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs) | ||
output = recursively_convert_from_numpy(outputs) | ||
result.append( | ||
TraceItem(symbol=item.symbol, | ||
inputs=item.inputs, | ||
output=output)) | ||
return result |