Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1406 Fix bug in ak.lookup #1407

Merged
merged 2 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion arkouda/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def in1dmulti(a, b, assume_unique=False, symmetric=False):
return atruth


def lookup(keys, values, arguments, fillvalue=-1):
def lookup(keys, values, arguments, fillvalue=-1, keys_from_unique=False):
"""
Apply the function defined by the mapping keys --> values to arguments.
Expand All @@ -187,6 +187,9 @@ def lookup(keys, values, arguments, fillvalue=-1):
(or tuple of dtypes, for a sequence) as keys.
fillvalue : scalar
The default value to return for arguments not in keys.
keys_from_unique : bool
If True, keys are assumed to be the output of ak.unique, e.g. the
.unique_keys attribute of a GroupBy instance.
Returns
-------
Expand Down Expand Up @@ -219,6 +222,12 @@ def lookup(keys, values, arguments, fillvalue=-1):
(array(['twenty', 'twenty', 'twenty']),
array(['four', 'one', 'two']))
"""
if not keys_from_unique:
keyg = GroupBy(keys)
if keyg.size != keyg.ngroups:
raise NonUniqueError("Function keys must be unique.")
keys = keyg.unique_keys
values = values[keyg.permutation]
if isinstance(values, Categorical):
codes = lookup(keys, values.codes, arguments, fillvalue=values._NAcode)
return Categorical.from_codes(codes, values.categories, NAvalue=values.NAvalue)
Expand Down
21 changes: 21 additions & 0 deletions tests/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ def test_inner_join(self):
with self.assertRaises(ValueError):
l, r = ak.join.inner_join(left, right, wherefunc=ak.intersect1d, whereargs=(ak.arange(10), ak.arange(5)))

def test_lookup(self):
keys = ak.arange(5)
values = 10*keys
args = ak.array([5, 3, 1, 4, 2, 3, 1, 0])
ans = np.array([-1, 30, 10, 40, 20, 30, 10, 0])
# Simple lookup with int keys
# Also test shortcut for unique-ordered keys
res = ak.lookup(keys, values, args, fillvalue=-1, keys_from_unique=True)
self.assertTrue((res.to_ndarray() == ans).all())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it's not consistent across our tests, but I prefer

self.assertListEqual(res.to_ndarray().tolist(), ans.tolist())

because it gives more precise information about where the lists are different

for example

>       self.assertTrue((ak.array([1,2,3]) == ak.array([1,1,3])).all())
E       AssertionError: False is not true

vs

>       self.assertListEqual(ak.array([1,2,3]).to_ndarray().tolist(), ak.array([1,1,3]).to_ndarray().tolist())
E       AssertionError: Lists differ: [1, 2, 3] != [1, 1, 3]
E       
E       First differing element 1:
E       2
E       1
E       
E       - [1, 2, 3]
E       ?     ^
E       
E       + [1, 1, 3]
E       ?     ^

# Compound lookup with (str, int) keys
res2 = ak.lookup((ak.cast(keys, ak.str_), keys), values, (ak.cast(args, ak.str_), args), fillvalue=-1)
self.assertTrue((res2.to_ndarray() == ans).all())
# Keys not in uniqued order
res3 = ak.lookup(keys[::-1], values[::-1], args, fillvalue=-1)
self.assertTrue((res3.to_ndarray() == ans).all())
# Non-unique keys should raise error
with self.assertRaises(ak.NonUniqueError):
keys = ak.arange(10) % 5
values = 10 * keys
ak.lookup(keys, values, args)

def test_error_handling(self):
"""
Tests error TypeError and ValueError handling
Expand Down