Skip to content

Commit

Permalink
ENH: mask support for hastable functions for indexing (#48396)
Browse files Browse the repository at this point in the history
* ENH: mask support for hastable functions for indexing

* Fix mypy

* Adjust test

* Add comment

* Add docstring

* Refactor into own functions

* Fix typing
  • Loading branch information
phofl authored Oct 20, 2022
1 parent 0e88f4f commit f271445
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 16 deletions.
48 changes: 48 additions & 0 deletions pandas/_libs/hashtable.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,75 +41,123 @@ cdef class HashTable:

cdef class UInt64HashTable(HashTable):
cdef kh_uint64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint64_t val)
cpdef set_item(self, uint64_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Int64HashTable(HashTable):
cdef kh_int64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int64_t val)
cpdef set_item(self, int64_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class UInt32HashTable(HashTable):
cdef kh_uint32_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint32_t val)
cpdef set_item(self, uint32_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Int32HashTable(HashTable):
cdef kh_int32_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int32_t val)
cpdef set_item(self, int32_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class UInt16HashTable(HashTable):
cdef kh_uint16_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint16_t val)
cpdef set_item(self, uint16_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Int16HashTable(HashTable):
cdef kh_int16_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int16_t val)
cpdef set_item(self, int16_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class UInt8HashTable(HashTable):
cdef kh_uint8_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, uint8_t val)
cpdef set_item(self, uint8_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Int8HashTable(HashTable):
cdef kh_int8_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, int8_t val)
cpdef set_item(self, int8_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Float64HashTable(HashTable):
cdef kh_float64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, float64_t val)
cpdef set_item(self, float64_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Float32HashTable(HashTable):
cdef kh_float32_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, float32_t val)
cpdef set_item(self, float32_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Complex64HashTable(HashTable):
cdef kh_complex64_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, complex64_t val)
cpdef set_item(self, complex64_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class Complex128HashTable(HashTable):
cdef kh_complex128_t *table
cdef int64_t na_position
cdef bint uses_mask

cpdef get_item(self, complex128_t val)
cpdef set_item(self, complex128_t key, Py_ssize_t val)
cpdef get_na(self)
cpdef set_na(self, Py_ssize_t val)

cdef class PyObjectHashTable(HashTable):
cdef kh_pymap_t *table
Expand Down
8 changes: 5 additions & 3 deletions pandas/_libs/hashtable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,18 @@ class ObjectVector:

class HashTable:
# NB: The base HashTable class does _not_ actually have these methods;
# we are putting the here for the sake of mypy to avoid
# we are putting them here for the sake of mypy to avoid
# reproducing them in each subclass below.
def __init__(self, size_hint: int = ...) -> None: ...
def __init__(self, size_hint: int = ..., uses_mask: bool = ...) -> None: ...
def __len__(self) -> int: ...
def __contains__(self, key: Hashable) -> bool: ...
def sizeof(self, deep: bool = ...) -> int: ...
def get_state(self) -> dict[str, int]: ...
# TODO: `item` type is subclass-specific
def get_item(self, item): ... # TODO: return type?
def set_item(self, item) -> None: ...
def set_item(self, item, val) -> None: ...
def get_na(self): ... # TODO: return type?
def set_na(self, val) -> None: ...
def map_locations(
self,
values: np.ndarray, # np.ndarray[subclass-specific]
Expand Down
99 changes: 86 additions & 13 deletions pandas/_libs/hashtable_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -396,23 +396,32 @@ dtypes = [('Complex128', 'complex128', 'khcomplex128_t', 'to_khcomplex128_t'),

cdef class {{name}}HashTable(HashTable):

def __cinit__(self, int64_t size_hint=1):
def __cinit__(self, int64_t size_hint=1, bint uses_mask=False):
self.table = kh_init_{{dtype}}()
size_hint = min(kh_needed_n_buckets(size_hint), SIZE_HINT_LIMIT)
kh_resize_{{dtype}}(self.table, size_hint)

self.uses_mask = uses_mask
self.na_position = -1

def __len__(self) -> int:
return self.table.size
return self.table.size + (0 if self.na_position == -1 else 1)

def __dealloc__(self):
if self.table is not NULL:
kh_destroy_{{dtype}}(self.table)
self.table = NULL

def __contains__(self, object key) -> bool:
# The caller is responsible to check for compatible NA values in case
# of masked arrays.
cdef:
khiter_t k
{{c_type}} ckey

if self.uses_mask and checknull(key):
return -1 != self.na_position

ckey = {{to_c_type}}(key)
k = kh_get_{{dtype}}(self.table, ckey)
return k != self.table.n_buckets
Expand All @@ -435,30 +444,73 @@ cdef class {{name}}HashTable(HashTable):
}

cpdef get_item(self, {{dtype}}_t val):
"""Extracts the position of val from the hashtable.

Parameters
----------
val : Scalar
The value that is looked up in the hashtable

Returns
-------
The position of the requested integer.
"""

# Used in core.sorting, IndexEngine.get_loc
# Caller is responsible for checking for pd.NA
cdef:
khiter_t k
{{c_type}} cval

cval = {{to_c_type}}(val)
k = kh_get_{{dtype}}(self.table, cval)
if k != self.table.n_buckets:
return self.table.vals[k]
else:
raise KeyError(val)

cpdef get_na(self):
"""Extracts the position of na_value from the hashtable.

Returns
-------
The position of the last na value.
"""

if not self.uses_mask:
raise NotImplementedError

if self.na_position == -1:
raise KeyError("NA")
return self.na_position

cpdef set_item(self, {{dtype}}_t key, Py_ssize_t val):
# Used in libjoin
# Caller is responsible for checking for pd.NA
cdef:
khiter_t k
int ret = 0
{{c_type}} ckey

ckey = {{to_c_type}}(key)
k = kh_put_{{dtype}}(self.table, ckey, &ret)
if kh_exist_{{dtype}}(self.table, k):
self.table.vals[k] = val
else:
raise KeyError(key)

cpdef set_na(self, Py_ssize_t val):
# Caller is responsible for checking for pd.NA
cdef:
khiter_t k
int ret = 0
{{c_type}} ckey

if not self.uses_mask:
raise NotImplementedError

self.na_position = val

{{if dtype == "int64" }}
# We only use this for int64, can reduce build size and make .pyi
# more accurate by only implementing it for int64
Expand All @@ -480,22 +532,36 @@ cdef class {{name}}HashTable(HashTable):
{{endif}}

@cython.boundscheck(False)
def map_locations(self, const {{dtype}}_t[:] values) -> None:
def map_locations(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> None:
# Used in libindex, safe_sort
cdef:
Py_ssize_t i, n = len(values)
int ret = 0
{{c_type}} val
khiter_t k
int8_t na_position = self.na_position

if self.uses_mask and mask is None:
raise NotImplementedError # pragma: no cover

with nogil:
for i in range(n):
val= {{to_c_type}}(values[i])
k = kh_put_{{dtype}}(self.table, val, &ret)
self.table.vals[k] = i
if self.uses_mask:
for i in range(n):
if mask[i]:
na_position = i
else:
val= {{to_c_type}}(values[i])
k = kh_put_{{dtype}}(self.table, val, &ret)
self.table.vals[k] = i
else:
for i in range(n):
val= {{to_c_type}}(values[i])
k = kh_put_{{dtype}}(self.table, val, &ret)
self.table.vals[k] = i
self.na_position = na_position

@cython.boundscheck(False)
def lookup(self, const {{dtype}}_t[:] values) -> ndarray:
def lookup(self, const {{dtype}}_t[:] values, const uint8_t[:] mask = None) -> ndarray:
# -> np.ndarray[np.intp]
# Used in safe_sort, IndexEngine.get_indexer
cdef:
Expand All @@ -504,15 +570,22 @@ cdef class {{name}}HashTable(HashTable):
{{c_type}} val
khiter_t k
intp_t[::1] locs = np.empty(n, dtype=np.intp)
int8_t na_position = self.na_position

if self.uses_mask and mask is None:
raise NotImplementedError # pragma: no cover

with nogil:
for i in range(n):
val = {{to_c_type}}(values[i])
k = kh_get_{{dtype}}(self.table, val)
if k != self.table.n_buckets:
locs[i] = self.table.vals[k]
if self.uses_mask and mask[i]:
locs[i] = na_position
else:
locs[i] = -1
val = {{to_c_type}}(values[i])
k = kh_get_{{dtype}}(self.table, val)
if k != self.table.n_buckets:
locs[i] = self.table.vals[k]
else:
locs[i] = -1

return np.asarray(locs)

Expand Down
Loading

0 comments on commit f271445

Please sign in to comment.