Skip to content

Commit

Permalink
Closes Bears-R-Us#3823 flatten function to match numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Oct 9, 2024
1 parent be9f7ab commit 17084d7
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 5 deletions.
1 change: 1 addition & 0 deletions ServerModules.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

ArgSortMsg
ArraySetopsMsg
AryUtil
BroadcastMsg
CastMsg
ConcatenateMsg
Expand Down
4 changes: 2 additions & 2 deletions arkouda/numpy/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def flip(

elif isinstance(x, Strings):
rep_msg = generic_msg(
cmd="flipString", args={"objType": x.objType, "obj": x.entry, "size": x.size}
)
cmd="flipString", args={"objType": x.objType, "obj": x.entry, "size": x.size}
)
return Strings.from_return_msg(cast(str, rep_msg))
else:
raise TypeError("flip only accepts type pdarray, Strings, or Categorical.")
19 changes: 16 additions & 3 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,8 +521,7 @@ def _binop(self, other: pdarray, op: str) -> pdarray:
except ValueError:
raise ValueError(f"shape mismatch {self.shape} {other.shape}")
repMsg = generic_msg(
cmd=f"binopvv<{self.dtype},{other.dtype},{x1.ndim}>",
args={"op": op, "a": x1, "b": x2}
cmd=f"binopvv<{self.dtype},{other.dtype},{x1.ndim}>", args={"op": op, "a": x1, "b": x2}
)
if tmp_x1:
del x1
Expand Down Expand Up @@ -779,7 +778,7 @@ def opeq(self, other, op):
raise ValueError(f"shape mismatch {self.shape} {other.shape}")
generic_msg(
cmd=f"opeqvv<{self.dtype},{other.dtype},{self.ndim}>",
args={"op": op, "a": self, "b": other}
args={"op": op, "a": self, "b": other},
)
return self
# pdarray binop scalar
Expand Down Expand Up @@ -1757,6 +1756,20 @@ def reshape(self, *shape):
),
)

def flatten(self):
"""
Return a copy of the array collapsed into one dimension.
Returns
-------
A copy of the input array, flattened to one dimension.
"""
return create_pdarray(
generic_msg(
cmd=f"flatten<{self.dtype.name},{self.ndim}>",
args={"a": self},
)
)

def to_ndarray(self) -> np.ndarray:
"""
Convert the array to a np.ndarray, transferring array data from the
Expand Down
14 changes: 14 additions & 0 deletions arkouda/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,6 +2082,20 @@ def _get_grouping_keys(self) -> List[Strings]:
"""
return [self]

def flatten(self):
"""
Return a copy of the array collapsed into one dimension.
Returns
-------
A copy of the input array, flattened to one dimension.
Note
----
Since multidimensional Strings are currently supported,
flatten on a Strings will always return itself.
"""
return self

def to_ndarray(self) -> np.ndarray:
"""
Convert the array to a np.ndarray, transferring array data from the
Expand Down
7 changes: 7 additions & 0 deletions src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ module AryUtil
/*
flatten a multi-dimensional array into a 1D array
*/
@arkouda.registerCommand
proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank > 1
{
Expand Down Expand Up @@ -1003,6 +1004,12 @@ module AryUtil
return flat;
}

proc flatten(const ref a: [?d] ?t): [] t throws
where a.rank == 1
{
return a;
}

// helper for computing an array element's index from its order
record orderer {
param rank: int;
Expand Down
14 changes: 14 additions & 0 deletions tests/pdarrayclass_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import arkouda as ak
from arkouda.testing import assert_equal as ak_assert_equal


class TestPdarrayClass:
Expand All @@ -11,3 +12,16 @@ def test_reshape(self):
r = a.reshape((2, 2))
assert r.shape == [2, 2]
assert isinstance(r, ak.pdarray)

@pytest.mark.parametrize("size", pytest.prob_size)
def test_flatten(self, size):
a = ak.arange(size)
ak_assert_equal(a.flatten(), a)

@pytest.mark.skip_if_max_rank_less_than(3)
@pytest.mark.parametrize("size", pytest.prob_size)
def test_flatten(self, size):
size = size - (size % 4)
a = ak.arange(size)
b = a.reshape((2, 2, size / 4))
ak_assert_equal(b.flatten(), a)
8 changes: 8 additions & 0 deletions tests/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import arkouda as ak
from arkouda.testing import assert_equal as ak_assert_equal

ak.verbose = False
N = 100
Expand Down Expand Up @@ -910,3 +911,10 @@ def test_in1d(self, size):
for word in more_words:
inds |= strings == word
assert (inds == matches).all()

@pytest.mark.parametrize("size", pytest.prob_size)
def test_flatten(self, size):
base_words, _ = self.base_words(size)
strings = self.get_strings(size, base_words)

ak_assert_equal(strings.flatten(), strings)

0 comments on commit 17084d7

Please sign in to comment.