diff --git a/arkouda/alignment.py b/arkouda/alignment.py index d0fcc5268e..484a4c209e 100644 --- a/arkouda/alignment.py +++ b/arkouda/alignment.py @@ -171,7 +171,7 @@ def in1dmulti(a, b, assume_unique=False, symmetric=False): return atruth -def lookup(keys, values, arguments, fillvalue=-1): +def lookup(keys, values, arguments, fillvalue=-1, keys_from_unique=False): """ Apply the function defined by the mapping keys --> values to arguments. @@ -187,6 +187,9 @@ def lookup(keys, values, arguments, fillvalue=-1): (or tuple of dtypes, for a sequence) as keys. fillvalue : scalar The default value to return for arguments not in keys. + keys_from_unique : bool + If True, keys are assumed to be the output of ak.unique, e.g. the + .unique_keys attribute of a GroupBy instance. Returns ------- @@ -219,6 +222,12 @@ def lookup(keys, values, arguments, fillvalue=-1): (array(['twenty', 'twenty', 'twenty']), array(['four', 'one', 'two'])) """ + if not keys_from_unique: + keyg = GroupBy(keys) + if keyg.size != keyg.ngroups: + raise NonUniqueError("Function keys must be unique.") + keys = keyg.unique_keys + values = values[keyg.permutation] if isinstance(values, Categorical): codes = lookup(keys, values.codes, arguments, fillvalue=values._NAcode) return Categorical.from_codes(codes, values.categories, NAvalue=values.NAvalue) diff --git a/tests/join_test.py b/tests/join_test.py index cff8806887..d5f793f093 100755 --- a/tests/join_test.py +++ b/tests/join_test.py @@ -105,6 +105,27 @@ def test_inner_join(self): with self.assertRaises(ValueError): l, r = ak.join.inner_join(left, right, wherefunc=ak.intersect1d, whereargs=(ak.arange(10), ak.arange(5))) + def test_lookup(self): + keys = ak.arange(5) + values = 10*keys + args = ak.array([5, 3, 1, 4, 2, 3, 1, 0]) + ans = np.array([-1, 30, 10, 40, 20, 30, 10, 0]) + # Simple lookup with int keys + # Also test shortcut for unique-ordered keys + res = ak.lookup(keys, values, args, fillvalue=-1, keys_from_unique=True) + self.assertTrue((res.to_ndarray() == ans).all()) + # Compound lookup with (str, int) keys + res2 = ak.lookup((ak.cast(keys, ak.str_), keys), values, (ak.cast(args, ak.str_), args), fillvalue=-1) + self.assertTrue((res2.to_ndarray() == ans).all()) + # Keys not in uniqued order + res3 = ak.lookup(keys[::-1], values[::-1], args, fillvalue=-1) + self.assertTrue((res3.to_ndarray() == ans).all()) + # Non-unique keys should raise error + with self.assertRaises(ak.NonUniqueError): + keys = ak.arange(10) % 5 + values = 10 * keys + ak.lookup(keys, values, args) + def test_error_handling(self): """ Tests error TypeError and ValueError handling