Skip to content

Commit

Permalink
Closes Bears-R-Us#3868: move squeeze functionality to arkouda.numpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 24, 2024
1 parent 2b71be1 commit 6c82eaa
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 52 deletions.
54 changes: 11 additions & 43 deletions arkouda/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
raise ValueError(f"Failed to broadcast array: {e}")


def concat(
arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
) -> Array:
def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
"""
Concatenate arrays along an axis.
Expand All @@ -70,9 +68,7 @@ def concat(
ndim = arrays[0].ndim
for a in arrays:
if a.ndim != ndim:
raise ValueError(
"all input arrays must have the same number of dimensions to concatenate"
)
raise ValueError("all input arrays must have the same number of dimensions to concatenate")

(common_dt, _arrays) = promote_to_common_dtype([a._array for a in arrays])

Expand Down Expand Up @@ -192,15 +188,11 @@ def moveaxis(
for s, d in zip(source, destination):
perm[s] = d
else:
raise ValueError(
"source and destination must both be tuples if source is a tuple"
)
raise ValueError("source and destination must both be tuples if source is a tuple")
elif isinstance(destination, int):
perm[source] = destination
else:
raise ValueError(
"source and destination must both be integers if source is a tuple"
)
raise ValueError("source and destination must both be integers if source is a tuple")

return permute_dims(x, axes=tuple(perm))

Expand Down Expand Up @@ -235,9 +227,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
raise IndexError(f"Failed to permute array dimensions: {e}")


def repeat(
x: Array, repeats: Union[int, Array], /, *, axis: Optional[int] = None
) -> Array:
def repeat(x: Array, repeats: Union[int, Array], /, *, axis: Optional[int] = None) -> Array:
"""
Repeat elements of an array.
Expand Down Expand Up @@ -277,9 +267,7 @@ def repeat(
raise NotImplementedError("repeat with 'axis' argument is not yet implemented")


def reshape(
x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> Array:
def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
"""
Reshape an array to a new shape.
Expand Down Expand Up @@ -355,9 +343,7 @@ def roll(
args={
"name": x._array,
"nShifts": len(shift) if isinstance(shift, tuple) else 1,
"shift": (
list(shift) if isinstance(shift, tuple) else [shift]
),
"shift": (list(shift) if isinstance(shift, tuple) else [shift]),
"nAxes": len(axisList),
"axis": axisList,
},
Expand All @@ -380,25 +366,9 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
axis : int or Tuple[int, ...]
The axis or axes to squeeze (must have a size of one).
"""
nAxes = len(axis) if isinstance(axis, tuple) else 1
try:
return Array._new(
create_pdarray(
cast(
str,
generic_msg(
cmd=f"squeeze<{x.dtype},{x.ndim},{x.ndim - nAxes}>",
args={
"name": x._array,
"nAxes": nAxes,
"axes": list(axis) if isinstance(axis, tuple) else [axis],
},
),
)
)
)
except RuntimeError as e:
raise ValueError(f"Failed to squeeze array: {e}")
from arkouda.numpy import squeeze

return Array._new(squeeze(x._array, axis))


def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
Expand All @@ -420,9 +390,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
ndim = arrays[0].ndim
for a in arrays:
if a.ndim != ndim:
raise ValueError(
"all input arrays must have the same number of dimensions to stack"
)
raise ValueError("all input arrays must have the same number of dimensions to stack")

(common_dt, _arrays) = promote_to_common_dtype([a._array for a in arrays])

Expand Down
97 changes: 88 additions & 9 deletions arkouda/numpy/_manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
# from __future__ import annotations

from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast
from typing import Optional, Tuple, Union, cast

from typeguard import typechecked

from arkouda.categorical import Categorical
from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray
from arkouda.pdarrayclass import pdarray
from arkouda.numpy.dtypes import numeric_scalars, bool_scalars
from arkouda.pdarrayclass import create_pdarray, pdarray
from arkouda.pdarraycreation import array as ak_array
from arkouda.strings import Strings
from arkouda.categorical import Categorical


__all__ = ["flip"]
__all__ = ["flip", "squeeze"]


def flip(
Expand Down Expand Up @@ -86,3 +85,83 @@ def flip(
return Strings.from_return_msg(cast(str, rep_msg))
else:
raise TypeError("flip only accepts type pdarray, Strings, or Categorical.")


@typechecked
def squeeze(
    x: Union[pdarray, numeric_scalars, bool_scalars], /, axis: Union[None, int, Tuple[int, ...]] = None
) -> pdarray:
    """
    Remove degenerate (size one) dimensions from an array.

    Parameters
    ----------
    x : pdarray, numeric_scalars, or bool_scalars
        The array to squeeze. A scalar is wrapped into a one-element
        pdarray and returned as-is; ``axis`` is ignored in that case.
    axis : None, int, or Tuple[int, ...], default=None
        The axis or axes to squeeze (must have a size of one).
        If axis = None, all dimensions of size 1 will be squeezed, except
        that one dimension is always kept when every dimension has size 1.

    Returns
    -------
    pdarray
        A copy of x with the dimensions specified in the axis argument removed.

    Raises
    ------
    ValueError
        If the squeeze request fails with a RuntimeError. The original
        error is attached as the cause.
    RuntimeError
        If x is neither a pdarray nor a scalar that could be converted
        to a pdarray.

    Examples
    --------
    >>> import arkouda as ak
    >>> ak.connect()
    >>> x = ak.arange(10).reshape((1, 10, 1))
    >>> x
    array([array([array([0]) array([1]) array([2]) array([3])....
    array([4]) array([5]) array([6]) array([7]) array([8]) array([9])])])
    >>> x.shape
    (1, 10, 1)
    >>> ak.squeeze(x,axis=None)
    array([0 1 2 3 4 5 6 7 8 9])
    >>> ak.squeeze(x,axis=None).shape
    (10,)
    >>> ak.squeeze(x,axis=2)
    array([array([0 1 2 3 4 5 6 7 8 9])])
    >>> ak.squeeze(x,axis=2).shape
    (1, 10)
    >>> ak.squeeze(x,axis=(0,2))
    array([0 1 2 3 4 5 6 7 8 9])
    >>> ak.squeeze(x,axis=(0,2)).shape
    (10,)
    """
    from arkouda.numpy.dtypes import _val_isinstance_of_union

    # Scalars have no dimensions to remove: wrap them in a 1-element array.
    if _val_isinstance_of_union(x, numeric_scalars) or _val_isinstance_of_union(x, bool_scalars):
        ret = ak_array([x])
        if isinstance(ret, pdarray):
            return ret

    if isinstance(x, pdarray):
        if axis is None:
            unit_dims = [i for i, extent in enumerate(x.shape) if extent == 1]
            # Can't squeeze over every dimension, so remove one if necessary
            if len(unit_dims) == len(x.shape):
                unit_dims.pop()
            axis = tuple(unit_dims)

        # Normalise once so the command's rank parameter and the axes
        # argument are always derived from the same list.
        axes = list(axis) if isinstance(axis, tuple) else [axis]
        n_axes = len(axes)

        try:
            rep_msg = generic_msg(
                cmd=f"squeeze<{x.dtype},{x.ndim},{x.ndim - n_axes}>",
                args={
                    "name": x,
                    "nAxes": n_axes,
                    "axes": axes,
                },
            )
            return create_pdarray(cast(str, rep_msg))
        except RuntimeError as e:
            # Chain the original error so its traceback is preserved.
            raise ValueError(f"Failed to squeeze array: {e}") from e

    raise RuntimeError("Failed to squeeze array.")
34 changes: 34 additions & 0 deletions tests/numpy/manipulation_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

seed = pytest.seed

DTYPES = ["uint64", "uint8", "int64", "float64", "bigint", "bool"]


class TestNumpyManipulationFunctions:

Expand Down Expand Up @@ -54,3 +56,35 @@ def test_flip_categorical(self, size):
# test case when c.permutation = None
c2 = Categorical(c.to_pandas())
assert_equal(ak.flip(c2), c2[::-1])

@pytest.mark.parametrize("size", pytest.prob_size)
@pytest.mark.parametrize("dtype", DTYPES)
def test_squeeze_1D(self, size, dtype):
    """Check squeeze with default axis on a 1-D array, a scalar, and a one-element array."""
    # A 1-D range compared against a freshly built copy of itself.
    flat = ak.arange(size, dtype=dtype)
    squeezed_flat = ak.squeeze(flat)
    assert_equal(squeezed_flat, ak.arange(size, dtype=dtype))

    # A Python scalar compared against a one-element array.
    scalar = 1
    squeezed_scalar = ak.squeeze(scalar)
    assert_equal(squeezed_scalar, ak.array([1]))

    # A one-element 1-D array compared against an equal one-element array.
    single = ak.array([1])
    squeezed_single = ak.squeeze(single)
    assert_equal(squeezed_single, ak.array([1]))

@pytest.mark.skip_if_max_rank_less_than(3)
@pytest.mark.parametrize("size", pytest.prob_size)
@pytest.mark.parametrize("dtype", DTYPES)
def test_squeeze(self, size, dtype):
    """Check squeeze on a (1, size, 1) array for axis None, a single int, and a tuple."""
    if dtype == "bigint":
        pytest.skip("Skip until #3870 is resolved.")

    padded = ak.arange(size, dtype=dtype).reshape((1, size, 1))

    # (axis argument, expected shape); a shape of None means the flat 1-D range.
    cases = (
        (None, None),
        (0, (size, 1)),
        (2, (1, size)),
        ((0, 2), None),
    )
    for axis, expected_shape in cases:
        actual = ak.squeeze(padded, axis=axis)
        expected = ak.arange(size, dtype=dtype)
        if expected_shape is not None:
            expected = expected.reshape(expected_shape)
        assert_equal(actual, expected)

    # A Python scalar compared against a one-element array.
    scalar = 1
    assert_equal(ak.squeeze(scalar), ak.array([1]))

    # An all-size-one 3-D array compared against a one-element 1-D array.
    nested = ak.array([[[1]]])
    assert_equal(ak.squeeze(nested), ak.array([1]))

0 comments on commit 6c82eaa

Please sign in to comment.