Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF: optimize algos.take for repeated calls #39692

Merged
merged 24 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4512f9c
PERF: optimize algos.take for repeated calls
jorisvandenbossche Feb 9, 2021
36c3ed2
fix nd check + fix cache differentiation of int / bool
jorisvandenbossche Feb 9, 2021
6d52932
fix non-scalar fill_value case
jorisvandenbossche Feb 9, 2021
ded773a
fix mypy
jorisvandenbossche Feb 9, 2021
2ee2543
try fix mypy
jorisvandenbossche Feb 9, 2021
f489ba5
fix annotation
jorisvandenbossche Feb 10, 2021
96305c5
improve docstrings
jorisvandenbossche Feb 10, 2021
480d2b4
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Feb 10, 2021
c70ac4d
faster EA check
jorisvandenbossche Feb 10, 2021
9fba887
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Feb 11, 2021
5273cd5
rename take_1d_array to take_1d
jorisvandenbossche Feb 11, 2021
d3dd4e4
add comment about being useful for array manager
jorisvandenbossche Feb 11, 2021
288c6f2
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Feb 15, 2021
06a3901
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 2, 2021
ca30487
use take_nd for now
jorisvandenbossche Mar 2, 2021
bf598a7
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 2, 2021
05b6b87
move caching of maybe_promote to cast.py
jorisvandenbossche Mar 2, 2021
2284813
move type comment
jorisvandenbossche Mar 2, 2021
4861fdb
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 2, 2021
a41ee6b
typo
jorisvandenbossche Mar 2, 2021
b52e1ec
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 4, 2021
76371cf
ensure deprecation warning is always raised
jorisvandenbossche Mar 4, 2021
2faf70b
single underscore
jorisvandenbossche Mar 4, 2021
1c19732
Merge remote-tracking branch 'upstream/master' into am-perf-take
jorisvandenbossche Mar 4, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 107 additions & 29 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
from __future__ import annotations

import functools
import operator
from textwrap import dedent
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union, cast
Expand Down Expand Up @@ -1534,40 +1535,79 @@ def _take_nd_object(arr, indexer, out, axis: int, fill_value, mask_info):
}


@functools.lru_cache(maxsize=128)
def __get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis):
    """
    Part of _get_take_nd_function below that doesn't need `mask_info` and thus
    can be cached (mask_info potentially contains a numpy ndarray which is not
    hashable and thus cannot be used as argument for cached function).

    Parameters
    ----------
    ndim : int
        Number of dimensions of the array to take from. Only 1 and 2 have
        specialized implementations.
    arr_dtype, out_dtype : objects exposing a ``.name`` attribute
        Dtypes of the input and output arrays. They are part of the cache key
        and therefore must be hashable.
    axis : int
        Only consulted when ``ndim == 2``: 0 selects the axis-0 lookup table,
        any other value selects the axis-1 lookup table.

    Returns
    -------
    callable or None
        The specialized take function, or None when there is no specialized
        implementation for this combination.
    """
    # Pick the lookup table once; both lookups below use the same table.
    if ndim == 1:
        take_dict = _take_1d_dict
    elif ndim == 2:
        take_dict = _take_2d_axis0_dict if axis == 0 else _take_2d_axis1_dict
    else:
        # No specialized implementation for other dimensions. Returning None
        # here (instead of falling through with `func` unbound) lets the
        # caller use its generic fallback.
        return None

    # First choice: an implementation for the exact (input, output) dtype pair.
    func = take_dict.get((arr_dtype.name, out_dtype.name))
    if func is not None:
        return func

    # Second choice: an implementation registered for (out_dtype, out_dtype),
    # wrapped with _convert_wrapper for the output dtype.
    func = take_dict.get((out_dtype.name, out_dtype.name))
    if func is not None:
        return _convert_wrapper(func, out_dtype)

    return None


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the caching not on this function? Having too many levels of indirection is a -1.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will clarify the comment above: the mask_info argument to this function is not hashable.

def _get_take_nd_function(
    ndim: int, arr_dtype, out_dtype, axis: int = 0, mask_info=None
):
    """
    Get the appropriate "take" implementation for the given dimension, axis
    and dtypes.

    Parameters
    ----------
    ndim : int
        Number of dimensions of the array to take from.
    arr_dtype, out_dtype
        Dtypes of the input and output arrays, forwarded to the cached lookup.
    axis : int, default 0
        Axis along which to take.
    mask_info : optional
        Forwarded unchanged to ``_take_nd_object`` by the generic fallback.
        It is kept out of the cached lookup because it may not be hashable.

    Returns
    -------
    callable
        A function with signature ``func(arr, indexer, out, fill_value=np.nan)``.
    """
    func = None
    if ndim <= 2:
        # for this part we don't need `mask_info` -> use the cached algo lookup
        func = __get_take_nd_function_cached(ndim, arr_dtype, out_dtype, axis)

    if func is None:
        # No specialized implementation found: fall back to the generic
        # object-based take, closing over `axis` and `mask_info`.

        def func(arr, indexer, out, fill_value=np.nan):
            indexer = ensure_int64(indexer)
            _take_nd_object(
                arr, indexer, out, axis=axis, fill_value=fill_value, mask_info=mask_info
            )

    return func
@functools.lru_cache(maxsize=128)
def _maybe_promote_cached(dtype, fill_value, fill_value_type):
    """
    Cached call to ``maybe_promote(dtype, fill_value)``.

    ``fill_value_type`` does not take part in the computation. It is only
    there to be included in the cache lookup, so that fill values such as
    ``1`` and ``True`` are differentiated and get separate cache entries.
    """
    promoted = maybe_promote(dtype, fill_value)
    return promoted


def _maybe_promote(dtype, fill_value):
jbrockmendel marked this conversation as resolved.
Show resolved Hide resolved
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very strange to do. Can you simply change all of the calls to _maybe_promote that need it to _maybe_promote_cached (e.g. in the code and not the tests)?

Putting a try/except inside here is reversing the paradigm and not good.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is not simply possible. We don't know in advance if the fill_value is going to be hashable or not. So that's the reason the fallback is needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but the try/except needs to be in the cached method, NOT here. In other words, you are now exposing two APIs; we need to have exactly one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the try/except is there because of non-hashable fill_values, and thus cannot be inside the cached method; that's the whole reason I added the try/except in the first place.

I am not exposing two different APIs. These are internal helper methods, and _maybe_promote_cached is only used for _maybe_promote, and it is _maybe_promote that is then used internally in the take implementation.
I can add a comment to _maybe_promote_cached that this is only the cached part of _maybe_promote and not to be used elsewhere.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simply ban non-hashables from _maybe_promote?

# error: Argument 3 to "__call__" of "_lru_cache_wrapper" has incompatible type
# "Type[Any]"; expected "Hashable" [arg-type]
return _maybe_promote_cached(
dtype, fill_value, type(fill_value)
) # type: ignore[arg-type]
except TypeError:
# if fill_value is not hashable (required for caching)
return maybe_promote(dtype, fill_value)


def take(arr, indices, axis: int = 0, allow_fill: bool = False, fill_value=None):
Expand Down Expand Up @@ -1677,7 +1717,7 @@ def _take_preprocess_indexer_and_fill_value(
else:
# check for promotion based on types only (do this first because
# it's faster than computing a mask)
dtype, fill_value = maybe_promote(arr.dtype, fill_value)
dtype, fill_value = _maybe_promote(arr.dtype, fill_value)
if dtype != arr.dtype and (out is None or out.dtype != dtype):
# check if promotion is actually required based on indexer
mask = indexer == -1
Expand Down Expand Up @@ -1786,6 +1826,44 @@ def take_nd(
take_1d = take_nd


def take_1d_array(
jorisvandenbossche marked this conversation as resolved.
Show resolved Hide resolved
arr: ArrayLike,
indexer: np.ndarray,
out=None,
fill_value=lib.no_default,
allow_fill: bool = True,
):
"""
Specialized version for 1D arrays. Differences compared to take_nd/take_1d:

- Assumes input (arr, indexer) has already been converted to numpy array / EA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add these assertions?

- Only works for 1D arrays

To ensure the lowest possible overhead.
"""
if fill_value is lib.no_default:
fill_value = na_value_for_dtype(arr.dtype, compat=False)

if isinstance(arr, ABCExtensionArray):
jbrockmendel marked this conversation as resolved.
Show resolved Hide resolved
# Check for EA to catch DatetimeArray, TimedeltaArray
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)

indexer, dtype, fill_value, mask_info = _take_preprocess_indexer_and_fill_value(
arr, indexer, 0, out, fill_value, allow_fill
)

# at this point, it's guaranteed that dtype can hold both the arr values
# and the fill_value
out = np.empty(indexer.shape, dtype=dtype)

func = _get_take_nd_function(
arr.ndim, arr.dtype, out.dtype, axis=0, mask_info=mask_info
)
func(arr, indexer, out, fill_value)

return out


def take_2d_multi(arr, indexer, fill_value=np.nan):
"""
Specialized Cython take which sets NaN values in one pass.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def unstack(self, unstacker, fill_value) -> ArrayManager:
new_arrays = []
for arr in self.arrays:
for i in range(unstacker.full_shape[1]):
new_arr = algos.take(
new_arr = algos.take_1d_array(
arr, new_indexer2D[:, i], allow_fill=True, fill_value=fill_value
)
new_arrays.append(new_arr)
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/test_take.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,18 @@ def test_take_axis_1(self):
with pytest.raises(IndexError, match="indices are out-of-bounds"):
algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)

def test_take_non_hashable_fill_value(self):
    # A list is not hashable, so this exercises take() with a fill_value
    # that cannot be used as a cache key.
    positions = np.array([1, -1])

    # for an integer array a list fill_value is rejected
    int_values = np.array([1, 2, 3])
    with pytest.raises(ValueError, match="fill_value must be a scalar"):
        algos.take(int_values, positions, allow_fill=True, fill_value=[1])

    # with object dtype it is allowed
    object_values = np.array([1, 2, 3], dtype=object)
    expected = np.array([2, [1]], dtype=object)
    result = algos.take(object_values, positions, allow_fill=True, fill_value=[1])
    tm.assert_numpy_array_equal(result, expected)


class TestExtensionTake:
# The take method found in pd.api.extensions
Expand Down