Skip to content

Commit

Permalink
Closes #3886: refactor idxReductionMsg
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Nov 8, 2024
1 parent f570fe0 commit 4ab8a9c
Show file tree
Hide file tree
Showing 8 changed files with 357 additions and 226 deletions.
58 changes: 4 additions & 54 deletions arkouda/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import arkouda as ak
from arkouda.client import generic_msg
from arkouda.numpy import cast as akcast
from arkouda.pdarrayclass import create_pdarray, create_pdarrays, parse_single_value
from arkouda.pdarraycreation import scalar_array
from arkouda.pdarrayclass import create_pdarray, create_pdarrays

from ._dtypes import _real_floating_dtypes, _real_numeric_dtypes
from .array_object import Array
from .manipulation_functions import broadcast_arrays, reshape, squeeze
from .manipulation_functions import broadcast_arrays


def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
Expand All @@ -31,31 +30,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmax")

if x.ndim > 1 and axis is None:
# must flatten ND arrays to 1D without an axis argument
x_op = reshape(x, shape=(-1,))
else:
x_op = x

resp = generic_msg(
cmd=f"reduce->idx{x_op.ndim}D",
args={
"x": x_op._array,
"op": "argmax",
"hasAxis": axis is not None,
"axis": axis if axis is not None else 0,
},
)

if axis is None:
return Array._new(scalar_array(parse_single_value(resp)))
else:
arr = Array._new(create_pdarray(resp))

if keepdims:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak.argmax(x._array, axis=axis, keepdims=keepdims))


def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
Expand All @@ -74,32 +49,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
"""
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmax")

if x.ndim > 1 and axis is None:
# must flatten ND arrays to 1D without an axis argument
x_op = reshape(x, shape=(-1,))
else:
x_op = x

resp = generic_msg(
cmd=f"reduce->idx{x_op.ndim}D",
args={
"x": x_op._array,
"op": "argmin",
"hasAxis": axis is not None,
"axis": axis if axis is not None else 0,
},
)

if axis is None:
return Array._new(scalar_array(parse_single_value(resp)))
else:
arr = Array._new(create_pdarray(resp))

if keepdims:
return arr
else:
return squeeze(arr, axis)
return Array._new(ak.argmin(x._array, axis=axis, keepdims=keepdims))


def nonzero(x: Array, /) -> Tuple[Array, ...]:
Expand Down
1 change: 1 addition & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
__all__ = [
"_datatype_check",
"ARKOUDA_SUPPORTED_DTYPES",
"ARKOUDA_SUPPORTED_INTS",
"DType",
"DTypeObjects",
"DTypes",
Expand Down
Loading

0 comments on commit 4ab8a9c

Please sign in to comment.