Skip to content

Commit

Permalink
Fixes #2617: Bug in comparison of segarrays containing empty segments
Browse files Browse the repository at this point in the history
This PR (fixes #2617):
- Fixes a bug in the equality check of segarrays containing empty segments. The failure was due to a bug in gen_ranges, not using the `selfcmp` groupby, and a size mismatch since segarr.all() excludes empty segs
- Fixed issue with empty segs in gen_ranges
- Added support for indexing segarray with a uint pdarray
- Fixed error message in indexing and dataframe indexing
- Added tests for segarray comparisons
  • Loading branch information
Pierce Hayes committed Aug 4, 2023
1 parent e322904 commit 2aa24e0
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 26 deletions.
18 changes: 4 additions & 14 deletions arkouda/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from arkouda.groupbyclass import GroupBy, broadcast
from arkouda.numeric import cumsum
from arkouda.pdarrayclass import create_pdarray, pdarray
from arkouda.pdarraycreation import arange, array, ones, zeros
from arkouda.pdarraycreation import arange, array, zeros
from arkouda.pdarraysetops import concatenate, in1d
from arkouda.segarray import gen_ranges as seg_gen_ranges
from arkouda.strings import Strings

__all__ = ["join_on_eq_with_dt", "gen_ranges", "compute_join_size"]
Expand Down Expand Up @@ -137,19 +138,8 @@ def gen_ranges(starts: pdarray, ends: pdarray) -> Tuple[pdarray, pdarray]:
ranges : pdarray, int64
The actual ranges, flattened into a single array
"""
if starts.size != ends.size:
raise ValueError("starts and ends must be same size")
if starts.size == 0:
return zeros(0, dtype=akint64), zeros(0, dtype=akint64)
lengths = ends - starts
if not (lengths > 0).all():
raise ValueError("all ends must be greater than starts")
segs = cumsum(lengths) - lengths
totlen = lengths.sum()
slices = ones(totlen, dtype=akint64)
diffs = concatenate((array([starts[0]]), starts[1:] - starts[:-1] - lengths[:-1] + 1))
slices[segs] = diffs
return segs, cumsum(slices)
# only maintain one version of gen_ranges
return seg_gen_ranges(starts, ends)


@typechecked
Expand Down
26 changes: 19 additions & 7 deletions arkouda/segarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from arkouda.dtypes import bool as akbool
from arkouda.dtypes import int64 as akint64
from arkouda.dtypes import isSupportedInt, str_
from arkouda.dtypes import uint64 as akuint64
from arkouda.groupbyclass import GroupBy, broadcast
from arkouda.logger import getArkoudaLogger
from arkouda.numeric import cumsum
Expand Down Expand Up @@ -56,11 +57,21 @@ def gen_ranges(starts, ends, stride=1):
if starts.size == 0:
return zeros(0, dtype=akint64), zeros(0, dtype=akint64)
lengths = (ends - starts) // stride
if not (lengths >= 0).all():
raise ValueError("all ends must be greater than or equal to starts")
non_empty = lengths != 0
segs = cumsum(lengths) - lengths
totlen = lengths.sum()
slices = ones(totlen, dtype=akint64)
diffs = concatenate((array([starts[0]]), starts[1:] - starts[:-1] - (lengths[:-1] - 1) * stride))
slices[segs] = diffs
non_empty_starts = starts[non_empty]
non_empty_lengths = lengths[non_empty]
diffs = concatenate(
(
array([non_empty_starts[0]]),
non_empty_starts[1:] - non_empty_starts[:-1] - (non_empty_lengths[:-1] - 1) * stride,
)
)
slices[segs[non_empty]] = diffs
return segs, cumsum(slices)


Expand Down Expand Up @@ -263,9 +274,7 @@ def __getitem__(self, i):
start = self.segments[i]
end = self.segments[i] + self.lengths[i]
return self.values[start:end]
elif (isinstance(i, pdarray) and (i.dtype == akint64 or i.dtype == akbool)) or isinstance(
i, slice
):
elif (isinstance(i, pdarray) and i.dtype in [akint64, akuint64, akbool]) or isinstance(i, slice):
starts = self.segments[i]
ends = starts + self.lengths[i]
newsegs, inds = gen_ranges(starts, ends)
Expand Down Expand Up @@ -352,13 +361,16 @@ def copy(self):
def __eq__(self, other):
if not isinstance(other, SegArray):
return NotImplemented
if self.size != other.size:
raise ValueError("Segarrays must have same size to compare")
eq = zeros(self.size, dtype=akbool)
leneq = self.lengths == other.lengths
if leneq.sum() > 0:
selfcmp = self[leneq]
othercmp = other[leneq]
intersection = self.all(selfcmp.values == othercmp.values)
eq[leneq] = intersection
intersection = selfcmp.all(selfcmp.values == othercmp.values)
eq[leneq & (self.lengths != 0)] = intersection
eq[leneq & (self.lengths == 0)] = True
return eq

def __len__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion src/DataFrameIndexingMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
throw new owned IllegalArgumentError(errorMsg);
}
if idxMax >= columnVals.size {
var errorMsg = "Error: %s: OOBindex %i > %i".doFormat(pn,idxMin,columnVals.size-1);
var errorMsg = "Error: %s: OOBindex %i > %i".doFormat(pn,idxMax,columnVals.size-1);
dfiLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
throw new owned IllegalArgumentError(errorMsg);
}
Expand Down
4 changes: 2 additions & 2 deletions src/IndexingMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ module IndexingMsg
return new MsgTuple(errorMsg,MsgType.ERROR);
}
if ivMax >= e.size {
var errorMsg = "Error: %s: OOBindex %i > %i".doFormat(pn,ivMin,e.size-1);
var errorMsg = "Error: %s: OOBindex %i > %i".doFormat(pn,ivMax,e.size-1);
imLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
}
Expand Down Expand Up @@ -399,7 +399,7 @@ module IndexingMsg
return new MsgTuple(errorMsg,MsgType.ERROR);
}
if ivMax >= e.size {
var errorMsg = "Error: %s: OOBindex %i > %i".doFormat(pn,ivMin,e.size-1);
var errorMsg = "Error: %s: OOBindex %i > %i".doFormat(pn,ivMax,e.size-1);
imLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg,MsgType.ERROR);
}
Expand Down
77 changes: 75 additions & 2 deletions tests/segarray_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile

import numpy as np
from base_test import ArkoudaTest
from context import arkouda as ak

Expand Down Expand Up @@ -642,7 +643,7 @@ def test_filter(self):
# test filtering single value retain empties
filter_result = sa.filter(2, discard_empty=False)
self.assertEqual(sa.size, filter_result.size)
#ensure 2 does not exist in return values
# ensure 2 does not exist in return values
self.assertTrue((filter_result.values != 2).all())
for i in range(sa.size):
self.assertListEqual(sa[i][(sa[i] != 2)].to_list(), filter_result[i].to_list())
Expand Down Expand Up @@ -680,7 +681,79 @@ def test_filter(self):
x = ak.in1d(ak.array(sa[i]), ak.array([1, 2]), invert=True)
v = ak.array(sa[i])[x]
if v.size != 0:
self.assertListEqual(v.to_list(), filter_result[i-offset].to_list())
self.assertListEqual(v.to_list(), filter_result[i - offset].to_list())
else:
offset += 1

def test_equality(self):
    """
    Exercise SegArray ``==`` on arrays that contain empty segments.

    Reproducer for issue #2617. Covers self-comparison for several
    placements of empty segments, a comparison of two different
    SegArrays, and self-comparison of DataFrame columns holding
    SegArrays of several value types.
    """
    # the segment offsets place empty segments at the front, in the
    # middle, and at the end of the SegArray
    segment_layouts = (
        [0, 0, 9, 14],
        [0, 9, 9, 14, 14],
        [0, 0, 7, 9, 14, 14, 17, 20],
    )
    for layout in segment_layouts:
        seg = ak.SegArray(ak.array(layout), ak.arange(-10, 10))
        self_equal = seg == seg
        self.assertTrue(self_equal.all())

    # same values, different segment boundaries: only the last two
    # segments line up
    left = ak.SegArray(ak.array([0, 4, 14, 14]), ak.arange(-10, 10))
    right = ak.SegArray(ak.array([0, 9, 14, 14]), ak.arange(-10, 10))
    expected = [False, False, True, True]
    self.assertTrue((left == right).to_list() == expected)

    # segarrays with empty segments, multiple value types, and edge cases
    letters = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
    float_edge_values = [
        np.nan,
        np.finfo(np.float64).min,
        -np.inf,
        -7.0,
        -3.14,
        -0.0,
        0.0,
        3.14,
        7.0,
        np.finfo(np.float64).max,
        np.inf,
        np.nan,
        np.nan,
        np.nan,
    ]
    columns = {
        "c_1": ak.SegArray(ak.array([0, 0, 9, 14]), ak.arange(-10, 10)),
        "c_2": ak.SegArray(
            ak.array([0, 5, 10, 10]), ak.arange(2**63, 2**63 + 15, dtype=ak.uint64)
        ),
        "c_3": ak.SegArray(ak.array([0, 0, 5, 10]), ak.randint(0, 1, 15, dtype=ak.bool)),
        "c_4": ak.SegArray(ak.array([0, 9, 14, 14]), ak.array(float_edge_values)),
        "c_5": ak.SegArray(ak.array([0, 2, 5, 5]), ak.array(letters)),
        "c_6": ak.SegArray(
            ak.array([0, 2, 2, 2]), ak.array(["a", "b", "", "c", "d", "e", "f", "g", "h", "i"])
        ),
        "c_7": ak.SegArray(ak.array([0, 0, 2, 2]), ak.array(letters)),
        "c_8": ak.SegArray(
            ak.array([0, 2, 3, 3]), ak.array(["", "'", " ", "test", "", "'", "", " ", ""])
        ),
        "c_9": ak.SegArray(ak.array([0, 5, 5, 8]), ak.array(letters)),
        "c_10": ak.SegArray(
            ak.array([0, 5, 8, 8]),
            ak.array(["abc", "123", "xyz", "l", "m", "n", "o", "p", "arkouda"]),
        ),
    }
    df = ak.DataFrame(columns)

    for name in df.columns:
        column = df[name]
        is_float = column.dtype == ak.float64
        if not is_float:
            self.assertTrue((column == column).all())
            continue

        # the float column holds NaNs, so it is compared through numpy
        # with equal_nan=True instead of through ==
        as_numpy = column.to_ndarray()
        if isinstance(as_numpy[0], np.ndarray):
            per_segment = (
                np.allclose(lhs, rhs, equal_nan=True) for lhs, rhs in zip(as_numpy, as_numpy)
            )
            self.assertTrue(all(per_segment))
        else:
            self.assertTrue(np.allclose(as_numpy, as_numpy, equal_nan=True))

0 comments on commit 2aa24e0

Please sign in to comment.