Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --external-config option to tools/torchscript_e2e_test.sh #347

Merged
merged 1 commit into from
Oct 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import re
import sys

from torch_mlir_e2e_test.torchscript.framework import run_tests
from torch_mlir_e2e_test.torchscript.framework import TestConfig, run_tests
from torch_mlir_e2e_test.torchscript.reporting import report_results
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY

Expand All @@ -20,7 +20,7 @@

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

from .xfail_sets import XFAIL_SETS
from .xfail_sets import XFAIL_SETS, COMMON_TORCH_MLIR_LOWERING_XFAILS

# Import tests to register them in the global registry.
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
Expand All @@ -35,9 +35,7 @@
from . import reduction

def _get_argparse():
# TODO: Allow pulling in an out-of-tree backend, so downstream can easily
# plug into the e2e tests.
config_choices = ['native_torch', 'torchscript', 'refbackend']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'external']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
Expand All @@ -47,6 +45,17 @@ def _get_argparse():
"refbackend": run through torch-mlir's RefBackend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
"external": use an external backend, specified by the `--external-backend` option.
''')
parser.add_argument('--external-config',
help=f'''
Specifies a path to a Python file, which will be `exec`'ed.
The file has the following contract:
- The global variable `config` should be set to an instance of `TestConfig`.
- `xfail_set` should be set to a set of test unique identifiers that are
expected to fail. The global `COMMON_TORCH_MLIR_LOWERING_XFAILS` provides
a common set of xfails that won't work on backends because torch-mlir
itself does not handle them.
''')
parser.add_argument('-f', '--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
Expand All @@ -71,10 +80,31 @@ def main():
# Find the selected config.
if args.config == 'refbackend':
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = XFAIL_SETS['refbackend']
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
xfail_set = XFAIL_SETS['native_torch']
elif args.config == 'torchscript':
config = TorchScriptTestConfig()
xfail_set = XFAIL_SETS['torchscript']
elif args.config == 'external':
with open(args.external_config, 'r') as f:
code = compile(f.read(), args.external_config, 'exec')
exec_globals = {
'COMMON_TORCH_MLIR_LOWERING_XFAILS': COMMON_TORCH_MLIR_LOWERING_XFAILS}
exec(code, exec_globals)
config = exec_globals.get('config')
xfail_set = exec_globals.get('xfail_set')
if config is None or not isinstance(config, TestConfig):
print(
f'ERROR: the script {args.external_config} did not set a global variable `config`'
)
sys.exit(1)
if xfail_set is None:
print(
f'ERROR: the script {args.external_config} did not set a global variable `xfail_set`'
)
sys.exit(1)

all_tests = list(GLOBAL_TEST_REGISTRY)
if args.serialized_test_dir:
Expand All @@ -101,7 +131,7 @@ def main():
results = run_tests(tests, config)

# Report the test results.
failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
failed = report_results(results, xfail_set, args.verbose)
sys.exit(1 if failed else 0)

if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# Lists of tests that fail to even reach the backends.
# These represent further work needed in torch-mlir to lower them properly
# to the backend contract.
_common_torch_mlir_lowering_xfails = {
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
'QuantizedMLP_basic',
}

XFAIL_SETS['refbackend'] = _common_torch_mlir_lowering_xfails
XFAIL_SETS['refbackend'] = COMMON_TORCH_MLIR_LOWERING_XFAILS

XFAIL_SETS['torchscript'] = {}

Expand Down