diff --git a/.github/workflows/script/install_binary.sh b/.github/workflows/script/install_binary.sh index bbd6b7df2f1..7bca0d4d2f3 100644 --- a/.github/workflows/script/install_binary.sh +++ b/.github/workflows/script/install_binary.sh @@ -4,6 +4,7 @@ source /intel-extension-for-transformers/.github/workflows/script/change_color.s cd /intel-extension-for-transformers export CMAKE_ARGS="-DNE_DNNL_CACHE_DIR=/cache" pip install -U pip +pip install -r requirements.txt $BOLD_YELLOW && echo "---------------- git submodule update --init --recursive -------------" && $RESET git config --global --add safe.directory "*" git submodule update --init --recursive diff --git a/intel_extension_for_transformers/qbits/__init__.py b/intel_extension_for_transformers/qbits/__init__.py index c23599090dc..5cd39d26a7f 100644 --- a/intel_extension_for_transformers/qbits/__init__.py +++ b/intel_extension_for_transformers/qbits/__init__.py @@ -16,5 +16,6 @@ # limitations under the License. import torch -if not torch.xpu._is_compiled(): - from intel_extension_for_transformers.qbits_py import * # pylint: disable=E0401, E0611 +import intel_extension_for_transformers +if "gpu" not in intel_extension_for_transformers.__version__: + from intel_extension_for_transformers.qbits_py import * # pylint: disable=E0401, E0611 diff --git a/setup.py b/setup.py index 17700afeeb3..13aec7b7025 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,12 @@ from pathlib import Path from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext +from setuptools_scm import get_version result = subprocess.Popen("pip install -r requirements.txt", shell=True) result.wait() + def is_intel_gpu_available(): import torch import intel_extension_for_pytorch as ipex @@ -286,6 +288,9 @@ def check_submodules(): "intel_extension_for_transformers/transformers/runtime/"), ]) cmdclass = {'build_ext': CMakeBuild} + itrex_version = get_version() + if IS_INTEL_GPU: + itrex_version = itrex_version + "-gpu" setup( name="intel-extension-for-transformers", @@ -324,4 +329,5 @@ def check_submodules(): ], setup_requires=['setuptools_scm'], use_scm_version=True, + version=itrex_version )