Skip to content

Commit

Permalink
Closes Bears-R-Us#3868: move squeeze functionality to arkouda.numpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 24, 2024
1 parent 2b71be1 commit cb00ee8
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 44 deletions.
54 changes: 11 additions & 43 deletions arkouda/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
raise ValueError(f"Failed to broadcast array: {e}")


def concat(
arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
) -> Array:
def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array:
"""
Concatenate arrays along an axis.
Expand All @@ -70,9 +68,7 @@ def concat(
ndim = arrays[0].ndim
for a in arrays:
if a.ndim != ndim:
raise ValueError(
"all input arrays must have the same number of dimensions to concatenate"
)
raise ValueError("all input arrays must have the same number of dimensions to concatenate")

(common_dt, _arrays) = promote_to_common_dtype([a._array for a in arrays])

Expand Down Expand Up @@ -192,15 +188,11 @@ def moveaxis(
for s, d in zip(source, destination):
perm[s] = d
else:
raise ValueError(
"source and destination must both be tuples if source is a tuple"
)
raise ValueError("source and destination must both be tuples if source is a tuple")
elif isinstance(destination, int):
perm[source] = destination
else:
raise ValueError(
"source and destination must both be integers if source is a tuple"
)
raise ValueError("source and destination must both be integers if source is a tuple")

return permute_dims(x, axes=tuple(perm))

Expand Down Expand Up @@ -235,9 +227,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
raise IndexError(f"Failed to permute array dimensions: {e}")


def repeat(
x: Array, repeats: Union[int, Array], /, *, axis: Optional[int] = None
) -> Array:
def repeat(x: Array, repeats: Union[int, Array], /, *, axis: Optional[int] = None) -> Array:
"""
Repeat elements of an array.
Expand Down Expand Up @@ -277,9 +267,7 @@ def repeat(
raise NotImplementedError("repeat with 'axis' argument is not yet implemented")


def reshape(
x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None
) -> Array:
def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
"""
Reshape an array to a new shape.
Expand Down Expand Up @@ -355,9 +343,7 @@ def roll(
args={
"name": x._array,
"nShifts": len(shift) if isinstance(shift, tuple) else 1,
"shift": (
list(shift) if isinstance(shift, tuple) else [shift]
),
"shift": (list(shift) if isinstance(shift, tuple) else [shift]),
"nAxes": len(axisList),
"axis": axisList,
},
Expand All @@ -380,25 +366,9 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
axis : int or Tuple[int, ...]
The axis or axes to squeeze (must have a size of one).
"""
nAxes = len(axis) if isinstance(axis, tuple) else 1
try:
return Array._new(
create_pdarray(
cast(
str,
generic_msg(
cmd=f"squeeze<{x.dtype},{x.ndim},{x.ndim - nAxes}>",
args={
"name": x._array,
"nAxes": nAxes,
"axes": list(axis) if isinstance(axis, tuple) else [axis],
},
),
)
)
)
except RuntimeError as e:
raise ValueError(f"Failed to squeeze array: {e}")
from arkouda.numpy import squeeze

return Array._new(squeeze(x._array, axis))


def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
Expand All @@ -420,9 +390,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
ndim = arrays[0].ndim
for a in arrays:
if a.ndim != ndim:
raise ValueError(
"all input arrays must have the same number of dimensions to stack"
)
raise ValueError("all input arrays must have the same number of dimensions to stack")

(common_dt, _arrays) = promote_to_common_dtype([a._array for a in arrays])

Expand Down
68 changes: 67 additions & 1 deletion arkouda/numpy/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from arkouda.categorical import Categorical


__all__ = ["flip"]
__all__ = ["flip", "squeeze"]


def flip(
Expand Down Expand Up @@ -86,3 +86,69 @@ def flip(
return Strings.from_return_msg(cast(str, rep_msg))
else:
raise TypeError("flip only accepts type pdarray, Strings, or Categorical.")


def squeeze(x: pdarray, /, axis: Union[None, int, Tuple[int, ...]]=None) -> pdarray:
"""
Remove degenerate (size one) dimensions from an array.
Parameters
----------
x : Array
The array to squeeze
axis : int or Tuple[int, ...]
The axis or axes to squeeze (must have a size of one).
If axis = None, all dimensions of size 1 will be squeezed.
Returns
-------
pdarray
A copy of x with the dimensions specified in the axis argument removed.
Examples
--------
>>> import arkouda as ak
>>> ak.connect()
>>> x = ak.arange(10).reshape((1, 10, 1))
>>> x
array([array([array([0]) array([1]) array([2]) array([3]) array([4]) array([5]) array([6]) array([7]) array([8]) array([9])])])
>>> x.shape
(1, 10, 1)
>>> ak.squeeze(x,axis=None)
array([0 1 2 3 4 5 6 7 8 9])
>>> ak.squeeze(x,axis=None).shape
(10,)
>>> ak.squeeze(x,axis=2)
array([array([0 1 2 3 4 5 6 7 8 9])])
>>> ak.squeeze(x,axis=2).shape
(1, 10)
>>> ak.squeeze(x,axis=(0,2))
array([0 1 2 3 4 5 6 7 8 9])
>>> ak.squeeze(x,axis=(0,2)).shape
(10,)
"""
if axis is None:
axis = tuple([i for i in range(x.ndim) if x.shape[i] == 1])

nAxes = len(axis) if isinstance(axis, tuple) else 1
try:
return create_pdarray(
cast(
str,
generic_msg(
cmd=f"squeeze<{x.dtype},{x.ndim},{x.ndim - nAxes}>",
args={
"name": x,
"nAxes": nAxes,
"axes": list(axis) if isinstance(axis, tuple) else [axis],
},
),
)
)

except RuntimeError as e:
raise ValueError(f"Failed to squeeze array: {e}")

0 comments on commit cb00ee8

Please sign in to comment.