Skip to content

Commit

Permalink
REF: de-duplicate libjoin (pandas-dev#46256)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Mar 13, 2022
1 parent c0b467f commit 6334376
Showing 1 changed file with 89 additions and 149 deletions.
238 changes: 89 additions & 149 deletions pandas/_libs/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@ def left_outer_join(const intp_t[:] left, const intp_t[:] right,
with nogil:
# First pass, determine size of result set, do not use the NA group
for i in range(1, max_groups + 1):
if right_count[i] > 0:
count += left_count[i] * right_count[i]
lc = left_count[i]
rc = right_count[i]

if rc > 0:
count += lc * rc
else:
count += left_count[i]
count += lc

left_indexer = np.empty(count, dtype=np.intp)
right_indexer = np.empty(count, dtype=np.intp)
Expand Down Expand Up @@ -679,7 +682,8 @@ def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
by_t[:] left_by_values,
by_t[:] right_by_values,
bint allow_exact_matches=True,
tolerance=None):
tolerance=None,
bint use_hashtable=True):

cdef:
Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
Expand All @@ -701,12 +705,13 @@ def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)

if by_t is object:
hash_table = PyObjectHashTable(right_size)
elif by_t is int64_t:
hash_table = Int64HashTable(right_size)
elif by_t is uint64_t:
hash_table = UInt64HashTable(right_size)
if use_hashtable:
if by_t is object:
hash_table = PyObjectHashTable(right_size)
elif by_t is int64_t:
hash_table = Int64HashTable(right_size)
elif by_t is uint64_t:
hash_table = UInt64HashTable(right_size)

right_pos = 0
for left_pos in range(left_size):
Expand All @@ -718,19 +723,25 @@ def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
if allow_exact_matches:
while (right_pos < right_size and
right_values[right_pos] <= left_values[left_pos]):
hash_table.set_item(right_by_values[right_pos], right_pos)
if use_hashtable:
hash_table.set_item(right_by_values[right_pos], right_pos)
right_pos += 1
else:
while (right_pos < right_size and
right_values[right_pos] < left_values[left_pos]):
hash_table.set_item(right_by_values[right_pos], right_pos)
if use_hashtable:
hash_table.set_item(right_by_values[right_pos], right_pos)
right_pos += 1
right_pos -= 1

# save positions as the desired index
by_value = left_by_values[left_pos]
found_right_pos = (hash_table.get_item(by_value)
if by_value in hash_table else -1)
if use_hashtable:
by_value = left_by_values[left_pos]
found_right_pos = (hash_table.get_item(by_value)
if by_value in hash_table else -1)
else:
found_right_pos = right_pos

left_indexer[left_pos] = left_pos
right_indexer[left_pos] = found_right_pos

Expand All @@ -748,7 +759,8 @@ def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
by_t[:] left_by_values,
by_t[:] right_by_values,
bint allow_exact_matches=1,
tolerance=None):
tolerance=None,
bint use_hashtable=True):

cdef:
Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
Expand All @@ -770,12 +782,13 @@ def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)

if by_t is object:
hash_table = PyObjectHashTable(right_size)
elif by_t is int64_t:
hash_table = Int64HashTable(right_size)
elif by_t is uint64_t:
hash_table = UInt64HashTable(right_size)
if use_hashtable:
if by_t is object:
hash_table = PyObjectHashTable(right_size)
elif by_t is int64_t:
hash_table = Int64HashTable(right_size)
elif by_t is uint64_t:
hash_table = UInt64HashTable(right_size)

right_pos = right_size - 1
for left_pos in range(left_size - 1, -1, -1):
Expand All @@ -787,19 +800,26 @@ def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
if allow_exact_matches:
while (right_pos >= 0 and
right_values[right_pos] >= left_values[left_pos]):
hash_table.set_item(right_by_values[right_pos], right_pos)
if use_hashtable:
hash_table.set_item(right_by_values[right_pos], right_pos)
right_pos -= 1
else:
while (right_pos >= 0 and
right_values[right_pos] > left_values[left_pos]):
hash_table.set_item(right_by_values[right_pos], right_pos)
if use_hashtable:
hash_table.set_item(right_by_values[right_pos], right_pos)
right_pos -= 1
right_pos += 1

# save positions as the desired index
by_value = left_by_values[left_pos]
found_right_pos = (hash_table.get_item(by_value)
if by_value in hash_table else -1)
if use_hashtable:
by_value = left_by_values[left_pos]
found_right_pos = (hash_table.get_item(by_value)
if by_value in hash_table else -1)
else:
found_right_pos = (right_pos
if right_pos != right_size else -1)

left_indexer[left_pos] = left_pos
right_indexer[left_pos] = found_right_pos

Expand All @@ -820,15 +840,7 @@ def asof_join_nearest_on_X_by_Y(numeric_t[:] left_values,
tolerance=None):

cdef:
Py_ssize_t left_size, right_size, i
ndarray[intp_t] left_indexer, right_indexer, bli, bri, fli, fri
numeric_t bdiff, fdiff

left_size = len(left_values)
right_size = len(right_values)

left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)
ndarray[intp_t] bli, bri, fli, fri

# search both forward and backward
bli, bri = asof_join_backward_on_X_by_Y(
Expand All @@ -848,6 +860,27 @@ def asof_join_nearest_on_X_by_Y(numeric_t[:] left_values,
tolerance,
)

return _choose_smaller_timestamp(left_values, right_values, bli, bri, fli, fri)


cdef _choose_smaller_timestamp(
numeric_t[:] left_values,
numeric_t[:] right_values,
ndarray[intp_t] bli,
ndarray[intp_t] bri,
ndarray[intp_t] fli,
ndarray[intp_t] fri,
):
cdef:
ndarray[intp_t] left_indexer, right_indexer
Py_ssize_t left_size, i
numeric_t bdiff, fdiff

left_size = len(left_values)

left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)

for i in range(len(bri)):
# choose timestamp from right with smaller difference
if bri[i] != -1 and fri[i] != -1:
Expand All @@ -870,106 +903,30 @@ def asof_join_backward(numeric_t[:] left_values,
bint allow_exact_matches=True,
tolerance=None):

cdef:
Py_ssize_t left_pos, right_pos, left_size, right_size
ndarray[intp_t] left_indexer, right_indexer
bint has_tolerance = False
numeric_t tolerance_ = 0
numeric_t diff = 0

# if we are using tolerance, set our objects
if tolerance is not None:
has_tolerance = True
tolerance_ = tolerance

left_size = len(left_values)
right_size = len(right_values)

left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)

right_pos = 0
for left_pos in range(left_size):
# restart right_pos if it went negative in a previous iteration
if right_pos < 0:
right_pos = 0

# find last position in right whose value is less than left's
if allow_exact_matches:
while (right_pos < right_size and
right_values[right_pos] <= left_values[left_pos]):
right_pos += 1
else:
while (right_pos < right_size and
right_values[right_pos] < left_values[left_pos]):
right_pos += 1
right_pos -= 1

# save positions as the desired index
left_indexer[left_pos] = left_pos
right_indexer[left_pos] = right_pos

# if needed, verify that tolerance is met
if has_tolerance and right_pos != -1:
diff = left_values[left_pos] - right_values[right_pos]
if diff > tolerance_:
right_indexer[left_pos] = -1

return left_indexer, right_indexer
return asof_join_backward_on_X_by_Y(
left_values,
right_values,
None,
None,
allow_exact_matches=allow_exact_matches,
tolerance=tolerance,
use_hashtable=False,
)


def asof_join_forward(numeric_t[:] left_values,
numeric_t[:] right_values,
bint allow_exact_matches=True,
tolerance=None):

cdef:
Py_ssize_t left_pos, right_pos, left_size, right_size
ndarray[intp_t] left_indexer, right_indexer
bint has_tolerance = False
numeric_t tolerance_ = 0
numeric_t diff = 0

# if we are using tolerance, set our objects
if tolerance is not None:
has_tolerance = True
tolerance_ = tolerance

left_size = len(left_values)
right_size = len(right_values)

left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)

right_pos = right_size - 1
for left_pos in range(left_size - 1, -1, -1):
# restart right_pos if it went over in a previous iteration
if right_pos == right_size:
right_pos = right_size - 1

# find first position in right whose value is greater than left's
if allow_exact_matches:
while (right_pos >= 0 and
right_values[right_pos] >= left_values[left_pos]):
right_pos -= 1
else:
while (right_pos >= 0 and
right_values[right_pos] > left_values[left_pos]):
right_pos -= 1
right_pos += 1

# save positions as the desired index
left_indexer[left_pos] = left_pos
right_indexer[left_pos] = (right_pos
if right_pos != right_size else -1)

# if needed, verify that tolerance is met
if has_tolerance and right_pos != right_size:
diff = right_values[right_pos] - left_values[left_pos]
if diff > tolerance_:
right_indexer[left_pos] = -1

return left_indexer, right_indexer
return asof_join_forward_on_X_by_Y(
left_values,
right_values,
None,
None,
allow_exact_matches=allow_exact_matches,
tolerance=tolerance,
use_hashtable=False,
)


def asof_join_nearest(numeric_t[:] left_values,
Expand All @@ -978,29 +935,12 @@ def asof_join_nearest(numeric_t[:] left_values,
tolerance=None):

cdef:
Py_ssize_t left_size, i
ndarray[intp_t] left_indexer, right_indexer, bli, bri, fli, fri
numeric_t bdiff, fdiff

left_size = len(left_values)

left_indexer = np.empty(left_size, dtype=np.intp)
right_indexer = np.empty(left_size, dtype=np.intp)
ndarray[intp_t] bli, bri, fli, fri

# search both forward and backward
bli, bri = asof_join_backward(left_values, right_values,
allow_exact_matches, tolerance)
fli, fri = asof_join_forward(left_values, right_values,
allow_exact_matches, tolerance)

for i in range(len(bri)):
# choose timestamp from right with smaller difference
if bri[i] != -1 and fri[i] != -1:
bdiff = left_values[bli[i]] - right_values[bri[i]]
fdiff = right_values[fri[i]] - left_values[fli[i]]
right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
else:
right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
left_indexer[i] = bli[i]

return left_indexer, right_indexer
return _choose_smaller_timestamp(left_values, right_values, bli, bri, fli, fri)

0 comments on commit 6334376

Please sign in to comment.