From f271445c0ac34bad47704ac500834ae961849517 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 20 Oct 2022 17:32:49 +0100 Subject: [PATCH] ENH: mask support for hastable functions for indexing (#48396) * ENH: mask support for hastable functions for indexing * Fix mypy * Adjust test * Add comment * Add docstring * Refactor into own functions * Fix typing --- pandas/_libs/hashtable.pxd | 48 +++++++++++ pandas/_libs/hashtable.pyi | 8 +- pandas/_libs/hashtable_class_helper.pxi.in | 99 +++++++++++++++++++--- pandas/tests/libs/test_hashtable.py | 78 +++++++++++++++++ scripts/run_stubtest.py | 2 + 5 files changed, 219 insertions(+), 16 deletions(-) diff --git a/pandas/_libs/hashtable.pxd b/pandas/_libs/hashtable.pxd index 80d7ab58dc559..b32bd4880588d 100644 --- a/pandas/_libs/hashtable.pxd +++ b/pandas/_libs/hashtable.pxd @@ -41,75 +41,123 @@ cdef class HashTable: cdef class UInt64HashTable(HashTable): cdef kh_uint64_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, uint64_t val) cpdef set_item(self, uint64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int64HashTable(HashTable): cdef kh_int64_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, int64_t val) cpdef set_item(self, int64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class UInt32HashTable(HashTable): cdef kh_uint32_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, uint32_t val) cpdef set_item(self, uint32_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int32HashTable(HashTable): cdef kh_int32_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, int32_t val) cpdef set_item(self, int32_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class UInt16HashTable(HashTable): cdef kh_uint16_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, uint16_t val) cpdef set_item(self, uint16_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int16HashTable(HashTable): cdef kh_int16_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, int16_t val) cpdef set_item(self, int16_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class UInt8HashTable(HashTable): cdef kh_uint8_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, uint8_t val) cpdef set_item(self, uint8_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Int8HashTable(HashTable): cdef kh_int8_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, int8_t val) cpdef set_item(self, int8_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Float64HashTable(HashTable): cdef kh_float64_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, float64_t val) cpdef set_item(self, float64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Float32HashTable(HashTable): cdef kh_float32_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, float32_t val) cpdef set_item(self, float32_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Complex64HashTable(HashTable): cdef kh_complex64_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, complex64_t val) cpdef set_item(self, complex64_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class Complex128HashTable(HashTable): cdef kh_complex128_t *table + cdef int64_t na_position + cdef bint uses_mask cpdef get_item(self, complex128_t val) cpdef set_item(self, complex128_t key, Py_ssize_t val) + cpdef get_na(self) + cpdef set_na(self, Py_ssize_t val) cdef class PyObjectHashTable(HashTable): cdef kh_pymap_t *table diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi index 06ff1041d3cf7..e60ccdb29c6b2 100644 --- a/pandas/_libs/hashtable.pyi +++ b/pandas/_libs/hashtable.pyi @@ -110,16 +110,18 @@ class ObjectVector: class HashTable: # NB: The base HashTable class does _not_ actually have these methods; - # we are putting the here for the sake of mypy to avoid + # we are putting them here for the sake of mypy to avoid # reproducing them in each subclass below. - def __init__(self, size_hint: int = ...) -> None: ... + def __init__(self, size_hint: int = ..., uses_mask: bool = ...) -> None: ... def __len__(self) -> int: ... def __contains__(self, key: Hashable) -> bool: ... def sizeof(self, deep: bool = ...) -> int: ... def get_state(self) -> dict[str, int]: ... # TODO: `item` type is subclass-specific def get_item(self, item): ... # TODO: return type? - def set_item(self, item) -> None: ... + def set_item(self, item, val) -> None: ... + def get_na(self): ... # TODO: return type? + def set_na(self, val) -> None: ... def map_locations( self, values: np.ndarray, # np.ndarray[subclass-specific] diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 54260a9a90964..c6d8783d6f115 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -396,13 +396,16 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'), cdef class {{name}}HashTable(HashTable): - def __cinit__(self, int64_t size_hint=1): + def __cinit__(self, int64_t size_hint=1, bint uses_mask=False): self.table = kh_init_{{dtype}}() size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT) kh_resize_{{dtype}}(self.table, size_hint) + self.uses_mask = uses_mask + self.na_position = -1 + def __len__(self) -> int: - return self.table.size + return self.table.size + (0 if self.na_position == -1 else 1) def __dealloc__(self): if self.table is not NULL: @@ -410,9 +413,15 @@ cdef class {{name}}HashTable(HashTable): self.table = NULL def __contains__(self, object key) -> bool: + # The caller is responsible to check for compatible NA values in case + # of masked arrays. cdef: khiter_t k {{c_type}} ckey + + if self.uses_mask and checknull(key): + return -1 != self.na_position + ckey = {{to_c_type}}(key) k = kh_get_{{dtype}}(self.table, ckey) return k != self.table.n_buckets @@ -435,10 +444,24 @@ cdef class {{name}}HashTable(HashTable): } cpdef get_item(self, {{dtype}}_t val): + """Extracts the position of val from the hashtable. + + Parameters + ---------- + val : Scalar + The value that is looked up in the hashtable + + Returns + ------- + The position of the requested integer. + """ + # Used in core.sorting, IndexEngine.get_loc + # Caller is responsible for checking for pd.NA cdef: khiter_t k {{c_type}} cval + cval = {{to_c_type}}(val) k = kh_get_{{dtype}}(self.table, cval) if k != self.table.n_buckets: @@ -446,12 +469,29 @@ cdef class {{name}}HashTable(HashTable): else: raise KeyError(val) + cpdef get_na(self): + """Extracts the position of na_value from the hashtable. + + Returns + ------- + The position of the last na value. + """ + + if not self.uses_mask: + raise NotImplementedError + + if self.na_position == -1: + raise KeyError("NA") + return self.na_position + cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val): # Used in libjoin + # Caller is responsible for checking for pd.NA cdef: khiter_t k int ret = 0 {{c_type}} ckey + ckey = {{to_c_type}}(key) k = kh_put_{{dtype}}(self.table, ckey, &ret) if kh_exist_{{dtype}}(self.table, k): @@ -459,6 +499,18 @@ cdef class {{name}}HashTable(HashTable): else: raise KeyError(key) + cpdef set_na(self, Py_ssize_t val): + # Caller is responsible for checking for pd.NA + cdef: + khiter_t k + int ret = 0 + {{c_type}} ckey + + if not self.uses_mask: + raise NotImplementedError + + self.na_position = val + {{if dtype == "int64" }} # We only use this for int64, can reduce build size and make .pyi # more accurate by only implementing it for int64 @@ -480,22 +532,36 @@ cdef class {{name}}HashTable(HashTable): {{endif}} @cython.boundscheck(False) - def map_locations(self, const {{dtype}}_t[:] values) -> None: + def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None: # Used in libindex, safe_sort cdef: Py_ssize_t i, n = len(values) int ret = 0 {{c_type}} val khiter_t k + int8_t na_position = self.na_position + + if self.uses_mask and mask is None: + raise NotImplementedError # pragma: no cover with nogil: - for i in range(n): - val= {{to_c_type}}(values[i]) - k = kh_put_{{dtype}}(self.table, val, &ret) - self.table.vals[k] = i + if self.uses_mask: + for i in range(n): + if mask[i]: + na_position = i + else: + val= {{to_c_type}}(values[i]) + k = kh_put_{{dtype}}(self.table, val, &ret) + self.table.vals[k] = i + else: + for i in range(n): + val= {{to_c_type}}(values[i]) + k = kh_put_{{dtype}}(self.table, val, &ret) + self.table.vals[k] = i + self.na_position = na_position @cython.boundscheck(False) - def lookup(self, const {{dtype}}_t[:] values) -> ndarray: + def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray: # -> np.ndarray[np.intp] # Used in safe_sort, IndexEngine.get_indexer cdef: @@ -504,15 +570,22 @@ cdef class {{name}}HashTable(HashTable): {{c_type}} val khiter_t k intp_t[::1] locs = np.empty(n, dtype=np.intp) + int8_t na_position = self.na_position + + if self.uses_mask and mask is None: + raise NotImplementedError # pragma: no cover with nogil: for i in range(n): - val = {{to_c_type}}(values[i]) - k = kh_get_{{dtype}}(self.table, val) - if k != self.table.n_buckets: - locs[i] = self.table.vals[k] + if self.uses_mask and mask[i]: + locs[i] = na_position else: - locs[i] = -1 + val = {{to_c_type}}(values[i]) + k = kh_get_{{dtype}}(self.table, val) + if k != self.table.n_buckets: + locs[i] = self.table.vals[k] + else: + locs[i] = -1 return np.asarray(locs) diff --git a/pandas/tests/libs/test_hashtable.py b/pandas/tests/libs/test_hashtable.py index 0cd340ab80897..d9d281a0759da 100644 --- a/pandas/tests/libs/test_hashtable.py +++ b/pandas/tests/libs/test_hashtable.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +import re import struct import tracemalloc from typing import Generator @@ -76,6 +77,49 @@ def test_get_set_contains_len(self, table_type, dtype): assert table.get_item(index + 1) == 41 assert index + 2 not in table + table.set_item(index + 1, 21) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 21 + assert table.get_item(index + 1) == 21 + + with pytest.raises(KeyError, match=str(index + 2)): + table.get_item(index + 2) + + def test_get_set_contains_len_mask(self, table_type, dtype): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + index = 5 + table = table_type(55, uses_mask=True) + assert len(table) == 0 + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + with pytest.raises(KeyError, match="NA"): + table.get_na() + + table.set_item(index + 1, 41) + table.set_na(41) + assert pd.NA in table + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index) == 42 + assert table.get_item(index + 1) == 41 + assert table.get_na() == 41 + + table.set_na(21) + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index + 1) == 41 + assert table.get_na() == 21 + assert index + 2 not in table + with pytest.raises(KeyError, match=str(index + 2)): table.get_item(index + 2) @@ -101,6 +145,22 @@ def test_map_locations(self, table_type, dtype, writable): for i in range(N): assert table.get_item(keys[i]) == i + def test_map_locations_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys, np.array([False, False, True])) + for i in range(N - 1): + assert table.get_item(keys[i]) == i + + with pytest.raises(KeyError, match=re.escape(str(keys[N - 1]))): + table.get_item(keys[N - 1]) + + assert table.get_na() == 2 + def test_lookup(self, table_type, dtype, writable): N = 3 table = table_type() @@ -123,6 +183,24 @@ def test_lookup_wrong(self, table_type, dtype): result = table.lookup(wrong_keys) assert np.all(result == -1) + def test_lookup_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + mask = np.array([False, True, False]) + keys.flags.writeable = writable + table.map_locations(keys, mask) + result = table.lookup(keys, mask) + expected = np.arange(N) + tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64)) + + result = table.lookup(np.array([1 + N]).astype(dtype), np.array([False])) + tm.assert_numpy_array_equal( + result.astype(np.int64), np.array([-1], dtype=np.int64) + ) + def test_unique(self, table_type, dtype, writable): if dtype in (np.int8, np.uint8): N = 88 diff --git a/scripts/run_stubtest.py b/scripts/run_stubtest.py index d90f8575234e8..db7a327f231b5 100644 --- a/scripts/run_stubtest.py +++ b/scripts/run_stubtest.py @@ -36,10 +36,12 @@ "pandas._libs.hashtable.HashTable.factorize", "pandas._libs.hashtable.HashTable.get_item", "pandas._libs.hashtable.HashTable.get_labels", + "pandas._libs.hashtable.HashTable.get_na", "pandas._libs.hashtable.HashTable.get_state", "pandas._libs.hashtable.HashTable.lookup", "pandas._libs.hashtable.HashTable.map_locations", "pandas._libs.hashtable.HashTable.set_item", + "pandas._libs.hashtable.HashTable.set_na", "pandas._libs.hashtable.HashTable.sizeof", "pandas._libs.hashtable.HashTable.unique", # stubtest might be too sensitive