[HPU] Enhance numba check (#345)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Nov 27, 2024
1 parent a7fcb51 commit 61e04a8
Showing 3 changed files with 83 additions and 18 deletions.
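In short, the commit replaces the inline `import numba` try/except inside `pack_array_with_numba` with a single upfront `can_pack_with_numba()` check, adds helpers in `auto_round/utils.py` that verify Numba is installed and TBB is both installed and configured correctly, and adds unit tests covering the failure paths.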
21 changes: 3 additions & 18 deletions auto_round/export/export_to_itrex/model_wrapper.py
@@ -24,7 +24,7 @@
 from torch.autograd import Function
 from torch.nn import functional as F
 import numpy as np
-from auto_round.utils import logger
+from auto_round.utils import logger, can_pack_with_numba

 NF4 = [
     -1.0,
@@ -446,24 +446,9 @@ def pack_array_with_numba(
         # Try to pack with numba to accelerate the packing process.
         # If numba is not available or the packing method is not supported,
         # fallback to the torch implementation.
-        try:
-            import numba
-
-            numba.config.THREADING_LAYER = "safe"
-        except ImportError:
-            logger.warning(
-                "To accelerate packing, please install numba with `pip install numba tbb`."
-            )
-            return (
-                self.pack_tensor_with_torch(torch.from_numpy(raw_array)).cpu().numpy()
-            )
-        except Exception as e:
-            logger.warning(
-                f"Import numba failed with error: {e}, fallback to torch implementation."
-            )
-            return (
-                self.pack_tensor_with_torch(torch.from_numpy(raw_array)).cpu().numpy()
-            )
+        if not can_pack_with_numba():
+            return self.pack_tensor_with_torch(torch.from_numpy(raw_array)).cpu().numpy()
+
         from auto_round.export.export_to_itrex.bit_packer import bit_packers

         pack_func_name = (bits, compress_bits)
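For readability, the new guard amounts to the following standalone sketch. The real signature of pack_array_with_numba is not shown in the hunk, so the function name and parameters below are hypothetical stand-ins; only the guard itself mirrors the committed change.

import numpy as np
import torch

from auto_round.utils import can_pack_with_numba


def pack_with_fallback(raw_array: np.ndarray, pack_with_torch, pack_with_numba):
    # Hypothetical wrapper: `pack_with_torch` and `pack_with_numba` stand in for
    # the wrapper class's real packing methods.
    if not can_pack_with_numba():
        # Same fallback as the commit: pack via torch and return a NumPy array.
        return pack_with_torch(torch.from_numpy(raw_array)).cpu().numpy()
    return pack_with_numba(raw_array)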
61 changes: 61 additions & 0 deletions auto_round/utils.py
@@ -980,3 +980,64 @@ def compile_func(fun, device, enable_torch_compile):
    return compile_func_on_cuda_or_cpu(fun, enable_torch_compile)



def is_numba_available(): # pragma: no cover
    """Check if Numba is available."""
    try:
        import numba

        return True
    except ImportError:
        return False


def _is_tbb_installed(): # pragma: no cover
    import importlib.metadata

    try:
        importlib.metadata.version("tbb")
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


def _is_tbb_configured(): # pragma: no cover
    try:
        from numba.np.ufunc.parallel import _check_tbb_version_compatible

        # check if TBB is present and compatible
        _check_tbb_version_compatible()

        return True
    except ImportError as e:
        logger.warning_once(f"TBB not available: {e}")
        return False


def is_tbb_available(): # pragma: no cover
    """Check if TBB is available."""
    if not _is_tbb_installed():
        logger.warning_once("TBB is not installed, please install it with `pip install tbb`.")
        return False
    if not _is_tbb_configured():
        logger.warning_once(
            (
                "TBB is installed but not configured correctly. \n"
                "Please add the TBB library path to `LD_LIBRARY_PATH`, "
                "for example: `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/`."
            )
        )
        return False
    return True


def can_pack_with_numba(): # pragma: no cover
    """Check if Numba and TBB are available for packing.
    To pack tensor with Numba, both Numba and TBB are required, and TBB should be configured correctly.
    """
    if not is_numba_available():
        logger.warning_once("Numba is not installed, please install it with `pip install numba`.")
        return False
    if not is_tbb_available():
        return False
    return True
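As a quick way to see which prerequisite is missing on a given machine, the new helpers can be called directly; a minimal usage sketch, assuming auto-round is installed from this revision:

from auto_round.utils import can_pack_with_numba, is_numba_available, is_tbb_available

# Each helper logs a one-time warning describing what is missing or misconfigured.
print("numba installed:", is_numba_available())
print("tbb usable:", is_tbb_available())
print("can pack with numba:", can_pack_with_numba())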
19 changes: 19 additions & 0 deletions test/test_utils.py
@@ -0,0 +1,19 @@
from unittest.mock import patch
import auto_round.utils as auto_round_utils

class TestPackingWithNumba:

    @patch.object(auto_round_utils, "_is_tbb_installed", lambda: False)
    def test_tbb_not_installed(self):
        assert auto_round_utils.is_tbb_available() is False, "`is_tbb_available` should return False."
        assert auto_round_utils.can_pack_with_numba() is False, "`can_pack_with_numba` should return False."

    @patch.object(auto_round_utils, "_is_tbb_installed", lambda: True)
    @patch.object(auto_round_utils, "_is_tbb_configured", lambda: False)
    def test_tbb_installed_but_not_configured_right(self):
        assert auto_round_utils.is_tbb_available() is False, "`is_tbb_available` should return False."
        assert auto_round_utils.can_pack_with_numba() is False, "`can_pack_with_numba` should return False."

    @patch.object(auto_round_utils, "is_numba_available", lambda: False)
    def test_numba_not_installed(self):
        assert auto_round_utils.can_pack_with_numba() is False, "`can_pack_with_numba` should return False."
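
These tests use only `unittest.mock.patch` and plain asserts, so they can presumably be run with pytest (e.g. `pytest test/test_utils.py`). A positive-path case is not part of this commit, but a hypothetical additional method inside TestPackingWithNumba, in the same style, would look like:

    @patch.object(auto_round_utils, "is_numba_available", lambda: True)
    @patch.object(auto_round_utils, "is_tbb_available", lambda: True)
    def test_all_prerequisites_met(self):
        # Hypothetical extra case (not in the diff): with both checks patched
        # to succeed, packing with Numba should be reported as possible.
        assert auto_round_utils.can_pack_with_numba() is True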
