Skip to content

Commit

Permalink
👽 移除np.sctypes引用以适配Numpy 2.0
Browse files Browse the repository at this point in the history
✅ 添加dhash稳定性测试
  • Loading branch information
yanang007 committed Aug 18, 2024
1 parent 299e397 commit 40707a1
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
48 changes: 43 additions & 5 deletions metalpy/utils/dhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ def wrapper(func):
for t in hashed_types:
DHash.hashers[t] = func
return func

return wrapper


def register_lazy_dhasher(hashed_type: str):
    """Build a decorator that stores a function in ``DHash.hasher_creators``.

    :param hashed_type: key (a string) under which the decorated function is
        stored in ``DHash.hasher_creators``.
        NOTE(review): presumably the name of a type whose hasher should only
        be created on demand -- confirm against ``DHash._find_lazy_hasher``.
    :return: a decorator that records the function and returns it unchanged.
    """
    def _record(creator):
        # Register under the given key, then hand the function back untouched
        # so the decorated name stays usable in its defining module.
        DHash.hasher_creators[hashed_type] = creator
        return creator

    return _record


Expand Down Expand Up @@ -104,7 +106,11 @@ def __hash__(self):
return self.digest()

def __eq__(self, other):
    """Compare this hash with another ``DHash`` or with a raw hash value.

    :param other: either a ``DHash`` instance (its ``result`` is compared)
        or any other object, which is compared against ``self.result``
        directly (e.g. a previously recorded hash value).
    :return: the outcome of comparing that value with ``self.result``.
    """
    # Unwrap DHash operands; everything else is treated as a raw hash value.
    expected = other.result if isinstance(other, DHash) else other
    return expected == self.result

@staticmethod
def convert_to_dhashable(obj):
Expand All @@ -118,7 +124,9 @@ def convert_to_dhashable(obj):
hasher = DHash._find_lazy_hasher(t)

if hasher is None:
def fallback(x): DHash(x, convert=False)
def fallback(x):
DHash(x, convert=False)

hasher = getattr(t, '__dhash__', fallback)
if hasher == fallback:
warnings.warn(f'Cannot find dhasher for type `{t.__name__}`,'
Expand All @@ -140,8 +148,38 @@ def _find_lazy_hasher(t):
return hasher


@register_dhasher(*itertools.chain(*[np.sctypes[k] for k in np.sctypes if k != 'others']))
@register_dhasher(float, int, bool)
def _check_dhashable(t: type):
"""初步检查类型是否支持确定性哈希(只检查零值的哈希值是否唯一,即存在自行定义的哈希函数)
"""
d_hashable = False
try:
a, b = t(), t()
d_hashable = hash(a) == hash(b)
except TypeError:
d_hashable = False
finally:
return d_hashable


# np.sctypes is no longer referenced: a hard-coded set of fundamental scalar
# types is combined with other types that are discovered automatically.
# reference: https://github.com/numpy/numpy/issues/26778

# Listed explicitly so the fundamental types are always present
# (they are treated as deterministically hashable without being probed).
NUMPY_BASE_TYPES = {
    # signed integers
    np.int8, np.int16, np.int32, np.int64, np.intp, np.intc,
    # unsigned integers
    np.uint8, np.uint16, np.uint32, np.uint64, np.uintp, np.uintc,
    # floating point
    np.float16, np.float32, np.float64,
    # complex
    np.complex64, np.complex128,
    # date / time
    np.datetime64, np.timedelta64,
}

# Every other NumPy-defined scalar type that passes the deterministic-hash
# probe, merged with the fundamental types above.
NUMPY_SUPPORTED_TYPES = NUMPY_BASE_TYPES | {
    scalar_type
    for scalar_type in np.sctypeDict.values()
    if scalar_type.__module__ == np.__name__ and _check_dhashable(scalar_type)
}


@register_dhasher(*NUMPY_SUPPORTED_TYPES)
@register_dhasher(bool, int, float, complex)
def _hash_basic_type(obj):
    """Hash a basic scalar by wrapping it in ``DHash`` with conversion off.

    Registered as the hasher for the builtin ``bool``, ``int``, ``float`` and
    ``complex`` types and for every type in ``NUMPY_SUPPORTED_TYPES``.

    :param obj: the scalar value to hash.
    :return: ``DHash(obj, convert=False)``.
    """
    basic_hash = DHash(obj, convert=False)
    return basic_hash

Expand Down Expand Up @@ -196,7 +234,7 @@ def _hash_serializable(obj):
def _hash_array(arr: np.ndarray, n_samples=10, sparse=False):
if not sparse:
arr = arr.ravel()
rand = np.random.RandomState(int(arr[len(arr) // 2]) % 2**32)
rand = np.random.RandomState(int(arr[len(arr) // 2]) % 2 ** 32)
n_samples = min(n_samples, len(arr))

# 加入shape作为参数防止类似np.ones(100)和np.ones(1000)的冲突
Expand Down
71 changes: 71 additions & 0 deletions metalpy/utils/tests/test_dhash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pathlib

import numpy as np
import pandas as pd

from metalpy.utils.dhash import dhash


# Module-level function used as a hashing sample in test_dhash_stable, where
# its expected dhash value is pinned.
# NOTE(review): the pinned value presumably depends on this function's code,
# so editing the body may require updating the expected hash -- confirm.
def _dummy(x):
    return x + 225


def test_dhash_stable():
    """Verify that dhash is stable (hash results are identical across runs).

    Each record pairs a sample object with the hash value recorded for it;
    the test fails on the first sample whose current dhash differs.
    """
    dhash_records = [
        # Basic Python types
        (225, 7846919982185809875),
        (2.25, -3918227367074148393),
        (2 + 25j, -6412466123961202416),
        (True, 325184529882986853),
        (None, -7934642784484250388),
        (b'225', 5436264356540408418),  # bytes
        ('0225', 6433453438383710489),

        # Python collections
        ((2, '2', 5.0), -5582375266051148848),
        ([2, '2', 5.0], -5582375266051148848),
        ({0: 2.0, np.int8(2): '5'}, 1876027385350813336),

        # Python functions and classes
        (_dummy, 6449074981454838072),
        (int, -5394466986204997707),

        # Numpy types
        (np.int8(25), -5921789015889676150),
        (np.uint8(225), 7846919982185809875),
        (np.float32(225), 7846919982185809875),
        (np.int16(225), 7846919982185809875),
        (np.uint16(225), 7846919982185809875),
        (np.float64(225), 7846919982185809875),
        (np.intc(225), 7846919982185809875),
        (np.uintc(225), 7846919982185809875),
        (np.longdouble(225), 7846919982185809875),
        (np.int32(225), 7846919982185809875),
        (np.uint32(225), 7846919982185809875),
        (np.complex64(225), 7846919982185809875),
        (np.float16(225), 7846919982185809875),
        (np.int64(225), 7846919982185809875),
        (np.uint64(225), 7846919982185809875),
        (np.complex128(2 + 25j), -6412466123961202416),
        (np.bool_(True), 325184529882986853),
        (np.array([2, 25]), -5297488126673252258),

        # Numpy datetime types
        (np.datetime64('2021-12-25'), -8803879579640539530),
        (np.timedelta64(5, 'D'), 3974400344764180608),

        # Pandas types
        (pd.Series([2, 25]), -5297488126673252258),
    ]

    # Pathlib types: use the concrete path flavour that can be instantiated on
    # the current platform (the other one raises NotImplementedError).
    try:
        dhash_records.append((pathlib.WindowsPath('C:\\Users'), 6348354964056181059))
    except NotImplementedError:
        dhash_records.append((pathlib.PosixPath('/home'), 1298416137799441881))

    for k, v in dhash_records:
        # Hash once so the asserted value and the reported value are the same.
        hashed = dhash(k)
        assert hashed == v, f"Unexpected dhash result for {k}: got {hashed.result} (expected {v})"

0 comments on commit 40707a1

Please sign in to comment.