From 40707a11256c82cd32e05cf8abce012226f47d22 Mon Sep 17 00:00:00 2001 From: yanang007 Date: Sun, 18 Aug 2024 20:36:40 +0800 Subject: [PATCH] =?UTF-8?q?:alien:=20=E7=A7=BB=E9=99=A4np.sctypes=E5=BC=95?= =?UTF-8?q?=E7=94=A8=E4=BB=A5=E9=80=82=E9=85=8DNumpy=202.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :white_check_mark: 添加dhash稳定性测试 --- metalpy/utils/dhash.py | 48 ++++++++++++++++++--- metalpy/utils/tests/test_dhash.py | 71 +++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 5 deletions(-) create mode 100644 metalpy/utils/tests/test_dhash.py diff --git a/metalpy/utils/dhash.py b/metalpy/utils/dhash.py index 520fe82..9ad46a8 100644 --- a/metalpy/utils/dhash.py +++ b/metalpy/utils/dhash.py @@ -29,6 +29,7 @@ def wrapper(func): for t in hashed_types: DHash.hashers[t] = func return func + return wrapper @@ -36,6 +37,7 @@ def register_lazy_dhasher(hashed_type: str): def wrapper(func): DHash.hasher_creators[hashed_type] = func return func + return wrapper @@ -104,7 +106,11 @@ def __hash__(self): return self.digest() def __eq__(self, other): - return isinstance(other, DHash) and self.result == other.result + if isinstance(other, DHash): + val = other.result + else: + val = other + return val == self.result @staticmethod def convert_to_dhashable(obj): @@ -118,7 +124,9 @@ def convert_to_dhashable(obj): hasher = DHash._find_lazy_hasher(t) if hasher is None: - def fallback(x): DHash(x, convert=False) + def fallback(x): + DHash(x, convert=False) + hasher = getattr(t, '__dhash__', fallback) if hasher == fallback: warnings.warn(f'Cannot find dhasher for type `{t.__name__}`,' @@ -140,8 +148,38 @@ def _find_lazy_hasher(t): return hasher -@register_dhasher(*itertools.chain(*[np.sctypes[k] for k in np.sctypes if k != 'others'])) -@register_dhasher(float, int, bool) +def _check_dhashable(t: type): + """初步检查类型是否支持确定性哈希(只检查零值的哈希值是否唯一,即存在自行定义的哈希函数) + """ + d_hashable = False + try: + a, b = t(), t() + d_hashable = hash(a) == hash(b) + except TypeError: + d_hashable = False + finally: + return d_hashable + + +# 移除 np.sctypes 引用,改为硬编码基础类型检查结合自动获取其他类型 +# reference: https://github.com/numpy/numpy/issues/26778 +NUMPY_BASE_TYPES = { + np.int8, np.int16, np.int32, np.int64, np.intp, np.intc, + np.uint8, np.uint16, np.uint32, np.uint64, np.uintp, np.uintc, + np.float16, np.float32, np.float64, + np.complex64, np.complex128, + np.datetime64, np.timedelta64, +} # 直接检查保证基本类型存在(默认支持确定性哈希) + +NUMPY_SUPPORTED_TYPES = { + t + for t in np.sctypeDict.values() + if t.__module__ == np.__name__ and _check_dhashable(t) +} | NUMPY_BASE_TYPES # 检查其他支持确定性哈希的类型 + + +@register_dhasher(*NUMPY_SUPPORTED_TYPES) +@register_dhasher(bool, int, float, complex) def _hash_basic_type(obj): return DHash(obj, convert=False) @@ -196,7 +234,7 @@ def _hash_serializable(obj): def _hash_array(arr: np.ndarray, n_samples=10, sparse=False): if not sparse: arr = arr.ravel() - rand = np.random.RandomState(int(arr[len(arr) // 2]) % 2**32) + rand = np.random.RandomState(int(arr[len(arr) // 2]) % 2 ** 32) n_samples = min(n_samples, len(arr)) # 加入shape作为参数防止类似np.ones(100)和np.ones(1000)的冲突 diff --git a/metalpy/utils/tests/test_dhash.py b/metalpy/utils/tests/test_dhash.py new file mode 100644 index 0000000..3c0def6 --- /dev/null +++ b/metalpy/utils/tests/test_dhash.py @@ -0,0 +1,71 @@ +import pathlib + +import numpy as np +import pandas as pd + +from metalpy.utils.dhash import dhash + + +def _dummy(x): + return x + 225 + + +def test_dhash_stable(): + """验证dhash的稳定性(在不同次运行中哈希结果一致) + """ + dhash_records = [ + # Basic Python types + (225, 7846919982185809875), + (2.25, -3918227367074148393), + (2 + 25j, -6412466123961202416), + (True, 325184529882986853), + (None, -7934642784484250388), + (b'225', 5436264356540408418), # bytes + ('0225', 6433453438383710489), + + # Python collections + ((2, '2', 5.0), -5582375266051148848), + ([2, '2', 5.0], -5582375266051148848), + ({0: 2.0, np.int8(2): '5'}, 1876027385350813336), + + # Python functions and classes + (_dummy, 6449074981454838072), + (int, -5394466986204997707), + + # Numpy types + (np.int8(25), -5921789015889676150), + (np.uint8(225), 7846919982185809875), + (np.float32(225), 7846919982185809875), + (np.int16(225), 7846919982185809875), + (np.uint16(225), 7846919982185809875), + (np.float64(225), 7846919982185809875), + (np.intc(225), 7846919982185809875), + (np.uintc(225), 7846919982185809875), + (np.longdouble(225), 7846919982185809875), + (np.int32(225), 7846919982185809875), + (np.uint32(225), 7846919982185809875), + (np.complex64(225), 7846919982185809875), + (np.float16(225), 7846919982185809875), + (np.int64(225), 7846919982185809875), + (np.uint64(225), 7846919982185809875), + (np.complex128(2 + 25j), -6412466123961202416), + (np.bool_(True), 325184529882986853), + (np.array([2, 25]), -5297488126673252258), + + # Numpy datetime types + (np.datetime64('2021-12-25'), -8803879579640539530), + (np.timedelta64(5, 'D'), 3974400344764180608), + + # Pandas types + (pd.Series([2, 25]), -5297488126673252258), + ] + + # Pathlib types + try: + dhash_records.append((pathlib.WindowsPath('C:\\Users'), 6348354964056181059)) + except NotImplementedError: + dhash_records.append((pathlib.PosixPath('/home'), 1298416137799441881)) + + for k, v in dhash_records: + hashed = dhash(k) + assert dhash(k) == v, f"Unexpected dhash result for {k}: got {hashed.result} (expected {v})"