Skip to content

Commit

Permalink
Closes #3884: Remove _squeeze function (#3885)
Browse files Browse the repository at this point in the history
Co-authored-by: Amanda Potts <[email protected]>
  • Loading branch information
ajpotts and ajpotts authored Nov 5, 2024
1 parent 2bb26b1 commit 5289b58
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,6 @@ def _create_scalar_array(value):
)


def _squeeze(array: pdarray, degen_axes: List[int]):
"""
Remove degenerate axes from a pdarray
Requires the ManipulationMsg server module
"""
return create_pdarray(
generic_msg(
cmd=f"squeeze<{array.dtype},{array.ndim},{array.ndim-len(degen_axes)}>",
args={
"name": array,
"nAxes": len(degen_axes),
"axes": degen_axes,
},
)
)


def _slice_index(array: pdarray, starts: List[int], stops: List[int], strides: List[int]):
"""
Slice a pdarray with a set of start, stop and stride values
Expand Down Expand Up @@ -951,17 +933,20 @@ def __getitem__(self, key):
},
)
)
from arkouda.numpy import squeeze

# remove any degenerate dimensions
ret_array = _squeeze(temp2, degen_axes)
ret_array = squeeze(temp2, tuple(degen_axes))

else:
# all slice or scalar indices: use slice indexing only
maybe_degen_arr = _slice_index(self, starts, stops, strides)

if len(scalar_axes) > 0:
from arkouda.numpy import squeeze

# reduce the array rank if there are any scalar indices
ret_array = _squeeze(maybe_degen_arr, scalar_axes)
ret_array = squeeze(maybe_degen_arr, tuple(scalar_axes))
else:
ret_array = maybe_degen_arr

Expand Down Expand Up @@ -2738,19 +2723,32 @@ def _reduces_to_single_value(axis, ndim) -> bool:


# helper function for sum, min, max, prod
def _comon_reduction(
pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]], kind: str
):
def _comon_reduction(pda: pdarray, axis: Optional[Union[int, Tuple[int, ...]]], kind: str):
if kind not in ["sum", "min", "max", "prod"]:
raise ValueError(f"Unsupported reduction type: {kind}")

axis_ = [] if axis is None else ([axis,] if isinstance(axis, int) else list(axis))
axis_ = (
[]
if axis is None
else (
[
axis,
]
if isinstance(axis, int)
else list(axis)
)
)

if _reduces_to_single_value(axis_, pda.ndim):
return parse_single_value(cast(str, generic_msg(
cmd=f"{kind}All<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "skipNan": False},
)))
return parse_single_value(
cast(
str,
generic_msg(
cmd=f"{kind}All<{pda.dtype.name},{pda.ndim}>",
args={"x": pda, "skipNan": False},
),
)
)
else:
return create_pdarray(
generic_msg(
Expand Down

0 comments on commit 5289b58

Please sign in to comment.