Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #1394: Fancy pdarray type preservation in ak.concatenate #1402

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions arkouda/pdarraysetops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from arkouda.dtypes import uint64 as akuint64
from arkouda.dtypes import int64 as akint64
from arkouda.dtypes import bool as akbool
from arkouda.client_dtypes import BitVector
from arkouda.groupbyclass import unique, GroupBy, groupable, groupable_element_type

Categorical = ForwardRef('Categorical')
Expand Down Expand Up @@ -150,6 +151,7 @@ def concatenate(arrays: Sequence[Union[pdarray, Strings, 'Categorical', ]], # t
"""
from arkouda.categorical import Categorical as Categorical_
from arkouda.dtypes import int_scalars
from arkouda.util import get_callback
size: int_scalars = 0
objtype = None
dtype = None
Expand All @@ -160,8 +162,21 @@ def concatenate(arrays: Sequence[Union[pdarray, Strings, 'Categorical', ]], # t
mode = 'interleave'
if len(arrays) < 1:
raise ValueError("concatenate called on empty iterable")
callback = get_callback(list(arrays)[0])
if len(arrays) == 1:
return cast(groupable_element_type, arrays[0])
# return object as it's original type
return callback(arrays[0])

types = set([type(x) for x in arrays])
if len(types) != 1:
raise TypeError(f"Items must all have same type: {types}")

if isinstance(arrays[0], BitVector):
# everything should be a BitVector because all have the same type, but do isinstance for mypy's sake
widths = set([x.width for x in arrays if isinstance(x, BitVector)])
revs = set([x.reverse for x in arrays if isinstance(x, BitVector)])
if len(widths) != 1 or len(revs) != 1:
raise TypeError("BitVectors must all have same width and direction")

if hasattr(arrays[0], 'concatenate'):
return cast(Sequence[Categorical_],
Expand All @@ -188,14 +203,14 @@ def concatenate(arrays: Sequence[Union[pdarray, Strings, 'Categorical', ]], # t
size += a.size
if size == 0:
if objtype == "pdarray":
return zeros_like(cast(pdarray, arrays[0]))
return callback(zeros_like(cast(pdarray, arrays[0])))
else:
return arrays[0]

repMsg = generic_msg(cmd="concatenate", args="{} {} {} {}". \
format(len(arrays), objtype, mode, ' '.join(names)))
if objtype == "pdarray":
return create_pdarray(cast(str, repMsg))
return callback(create_pdarray(cast(str, repMsg)))
elif objtype == "str":
# ConcatenateMsg returns created attrib(name)+created nbytes=123
return Strings.from_return_msg(cast(str, repMsg))
Expand Down
27 changes: 7 additions & 20 deletions arkouda/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from arkouda.client_dtypes import BitVector, BitVectorizer, IPv4
from arkouda.timeclass import Datetime, Timedelta
from arkouda.pdarrayclass import attach_pdarray, pdarray, create_pdarray
from arkouda.pdarraysetops import concatenate as pdarrayconcatenate
from arkouda.pdarraycreation import arange
from arkouda.pdarraysetops import unique
from arkouda.pdarrayIO import read
Expand All @@ -35,25 +34,13 @@ def get_callback(x):

# TODO - moving this into arkouda, function name should probably be changed.
def concatenate(items, ordered=True):
if len(items) > 0:
types = set([type(x) for x in items])
if len(types) != 1:
raise TypeError("Items must all have same type: {}".format(types))
t = types.pop()
if t == BitVector:
widths = set([x.width for x in items])
revs = set([x.reverse for x in items])
if len(widths) != 1 or len(revs) != 1:
raise TypeError("BitVectors must all have same width and direction")
callback = get_callback(list(items)[0])
if hasattr(t, 'concat'):
concat = t.concat
else:
concat = pdarrayconcatenate
else:
callback = identity
concat = pdarrayconcatenate
return callback(concat(items, ordered=ordered))
# this version can be called with Dataframe and Series (which have Class.concat methods)
from arkouda.pdarraysetops import concatenate as pdarrayconcatenate
types = set([type(x) for x in items])
if len(types) != 1:
raise TypeError(f"Items must all have same type: {types}")
t = types.pop()
return t.concat(items, ordered=ordered) if hasattr(t, 'concat') else pdarrayconcatenate(items, ordered=ordered)


def report_mem(pre=''):
Expand Down
74 changes: 73 additions & 1 deletion tests/operator_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,79 @@ def testConcatenate(self):

self.assertTrue((ak.array([True, False, True, False, True, True]) ==
ak.concatenate([pdaOne,pdaTwo])).all())


def test_concatenate_type_preservation(self):
# Test that concatenate preserves special pdarray types (IPv4, Datetime, BitVector, ...)
from arkouda.util import concatenate as akuconcat
pda_one = ak.arange(1, 4)
pda_two = ak.arange(4, 7)
pda_concat = ak.concatenate([pda_one, pda_two])

# IPv4 test
ipv4_one = ak.IPv4(pda_one)
ipv4_two = ak.IPv4(pda_two)
ipv4_concat = ak.concatenate([ipv4_one, ipv4_two])
self.assertEqual(type(ipv4_concat), ak.IPv4)
self.assertListEqual(ak.IPv4(pda_concat).to_ndarray().tolist(), ipv4_concat.to_ndarray().tolist())
# test single and empty
self.assertEqual(type(ak.concatenate([ipv4_one])), ak.IPv4)
self.assertListEqual(ak.IPv4(pda_one).to_ndarray().tolist(), ak.concatenate([ipv4_one]).to_ndarray().tolist())
self.assertEqual(type(ak.concatenate([ak.IPv4(ak.array([], dtype=ak.int64))])), ak.IPv4)

# Datetime test
datetime_one = ak.Datetime(pda_one)
datetime_two = ak.Datetime(pda_two)
datetime_concat = ak.concatenate([datetime_one, datetime_two])
self.assertEqual(type(datetime_concat), ak.Datetime)
self.assertListEqual(ak.Datetime(pda_concat).to_ndarray().tolist(), datetime_concat.to_ndarray().tolist())
# test single and empty
self.assertEqual(type(ak.concatenate([datetime_one])), ak.Datetime)
self.assertListEqual(ak.Datetime(pda_one).to_ndarray().tolist(), ak.concatenate([datetime_one]).to_ndarray().tolist())
self.assertEqual(type(ak.concatenate([ak.Datetime(ak.array([], dtype=ak.int64))])), ak.Datetime)

# Timedelta test
timedelta_one = ak.Timedelta(pda_one)
timedelta_two = ak.Timedelta(pda_two)
timedelta_concat = ak.concatenate([timedelta_one, timedelta_two])
self.assertEqual(type(timedelta_concat), ak.Timedelta)
self.assertListEqual(ak.Timedelta(pda_concat).to_ndarray().tolist(), timedelta_concat.to_ndarray().tolist())
# test single and empty
self.assertEqual(type(ak.concatenate([timedelta_one])), ak.Timedelta)
self.assertListEqual(ak.Timedelta(pda_one).to_ndarray().tolist(), ak.concatenate([timedelta_one]).to_ndarray().tolist())
self.assertEqual(type(ak.concatenate([ak.Timedelta(ak.array([], dtype=ak.int64))])), ak.Timedelta)

# BitVector test
bitvector_one = ak.BitVector(pda_one)
bitvector_two = ak.BitVector(pda_two)
bitvector_concat = ak.concatenate([bitvector_one, bitvector_two])
self.assertEqual(type(bitvector_concat), ak.BitVector)
self.assertListEqual(ak.BitVector(pda_concat).to_ndarray().tolist(), bitvector_concat.to_ndarray().tolist())
# test single and empty
self.assertEqual(type(ak.concatenate([bitvector_one])), ak.BitVector)
self.assertListEqual(ak.BitVector(pda_one).to_ndarray().tolist(), ak.concatenate([bitvector_one]).to_ndarray().tolist())
self.assertEqual(type(ak.concatenate([ak.BitVector(ak.array([], dtype=ak.int64))])), ak.BitVector)

# Test failure with mixed types
with self.assertRaises(TypeError):
ak.concatenate(datetime_one, bitvector_two)

# verify ak.util.concatenate still works
ipv4_akuconcat = akuconcat([ipv4_one, ipv4_two])
self.assertEqual(type(ipv4_akuconcat), ak.IPv4)
self.assertListEqual(ak.IPv4(pda_concat).to_ndarray().tolist(), ipv4_akuconcat.to_ndarray().tolist())

datetime_akuconcat = akuconcat([datetime_one, datetime_two])
self.assertEqual(type(datetime_akuconcat), ak.Datetime)
self.assertListEqual(ak.Datetime(pda_concat).to_ndarray().tolist(), datetime_akuconcat.to_ndarray().tolist())

timedelta_akuconcat = akuconcat([timedelta_one, timedelta_two])
self.assertEqual(type(timedelta_akuconcat), ak.Timedelta)
self.assertListEqual(ak.Timedelta(pda_concat).to_ndarray().tolist(), timedelta_akuconcat.to_ndarray().tolist())

bitvector_akuconcat = akuconcat([bitvector_one, bitvector_two])
self.assertEqual(type(bitvector_akuconcat), ak.BitVector)
self.assertListEqual(ak.BitVector(pda_concat).to_ndarray().tolist(), bitvector_akuconcat.to_ndarray().tolist())

def testAllOperators(self):
run_tests(verbose)

Expand Down