From 293b0fe72da2d46db0ddbea7813c32c144ecc37b Mon Sep 17 00:00:00 2001 From: Amanda Potts Date: Mon, 4 Nov 2024 16:43:59 -0500 Subject: [PATCH] Closes #3884: Remove _squeeze function --- arkouda/pdarrayclass.py | 54 ++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index ff189c62a3..188bdaffda 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -140,24 +140,6 @@ def _create_scalar_array(value): ) -def _squeeze(array: pdarray, degen_axes: List[int]): - """ - Remove degenerate axes from a pdarray - - Requires the ManipulationMsg server module - """ - return create_pdarray( - generic_msg( - cmd=f"squeeze<{array.dtype},{array.ndim},{array.ndim-len(degen_axes)}>", - args={ - "name": array, - "nAxes": len(degen_axes), - "axes": degen_axes, - }, - ) - ) - - def _slice_index(array: pdarray, starts: List[int], stops: List[int], strides: List[int]): """ Slice a pdarray with a set of start, stop and stride values @@ -951,17 +933,20 @@ def __getitem__(self, key): }, ) ) + from arkouda.numpy import squeeze # remove any degenerate dimensions - ret_array = _squeeze(temp2, degen_axes) + ret_array = squeeze(temp2, tuple(degen_axes)) else: # all slice or scalar indices: use slice indexing only maybe_degen_arr = _slice_index(self, starts, stops, strides) if len(scalar_axes) > 0: + from arkouda.numpy import squeeze + # reduce the array rank if there are any scalar indices - ret_array = _squeeze(maybe_degen_arr, scalar_axes) + ret_array = squeeze(maybe_degen_arr, tuple(scalar_axes)) else: ret_array = maybe_degen_arr @@ -2738,19 +2723,32 @@ def _reduces_to_single_value(axis, ndim) -> bool: # helper function for sum, min, max, prod -def _comon_reduction( - pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]], kind: str -): +def _comon_reduction(pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]], kind: str): if kind not in ["sum", "min", "max", "prod"]: raise ValueError(f"Unsupported reduction type: {kind}") - axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis)) + axis_ = ( + [] + if axis is None + else ( + [ + axis, + ] + if isinstance(axis, int) + else list(axis) + ) + ) if _reduces_to_single_value(axis_, pda.ndim): - return parse_single_value(cast(str, generic_msg( - cmd=f"{kind}All<{pda.dtype.name},{pda.ndim}>", - args={"x": pda, "skipNan": False}, - ))) + return parse_single_value( + cast( + str, + generic_msg( + cmd=f"{kind}All<{pda.dtype.name},{pda.ndim}>", + args={"x": pda, "skipNan": False}, + ), + ) + ) else: return create_pdarray( generic_msg(