-
Notifications
You must be signed in to change notification settings - Fork 519
/
Copy pathupdate_torch_ods.sh
executable file
·52 lines (45 loc) · 2 KB
/
update_torch_ods.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/bin/bash
# Updates auto-generated ODS files for the `torch` dialect.
#
# Environment variables:
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
# which register custom PyTorch operators upon being imported.
# TORCH_MLIR_EXT_PYTHONPATH: colon-separated list of paths necessary
# for importing PyTorch extensions specified in TORCH_MLIR_EXT_MODULES.
# For more information on supporting custom operators, see:
# ${TORCH_MLIR}/python/torch_mlir/_torch_mlir_custom_op_example/README.md
set -euo pipefail
src_dir="$(realpath "$(dirname "$0")"/..)"
build_dir="$(realpath "${TORCH_MLIR_BUILD_DIR:-$src_dir/build}")"
torch_ir_include_dir="${src_dir}/include/torch-mlir/Dialect/Torch/IR"
in_tree_pkg_dir="${build_dir}/tools/torch-mlir/python_packages"
out_of_tree_pkg_dir="${build_dir}/python_packages"
if [[ ! -d "${in_tree_pkg_dir}" && ! -d "${out_of_tree_pkg_dir}" ]]; then
echo "Couldn't find in-tree or out-of-tree build, exiting."
exit 1
fi
# The `-nt` check works even if one of the two directories is missing.
if [[ "${in_tree_pkg_dir}" -nt "${out_of_tree_pkg_dir}" ]]; then
python_packages_dir="${in_tree_pkg_dir}"
else
python_packages_dir="${out_of_tree_pkg_dir}"
fi
TORCH_MLIR_EXT_PYTHONPATH="${TORCH_MLIR_EXT_PYTHONPATH:-""}"
pypath="${python_packages_dir}/torch_mlir"
if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
fi
TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
ext_module="${ext_module:-""}"
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
ext_module="${TORCH_MLIR_EXT_MODULES}"
fi
set +u
# To enable this python package, manually build torch_mlir with:
# -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON
# TODO: move this package out of JIT_IR_IMPORTER.
PYTHONPATH="${PYTHONPATH}:${pypath}" python3 \
-m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen \
--torch_ir_include_dir="${torch_ir_include_dir}" \
--pytorch_op_extensions="${ext_module}" \
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"