diff --git a/src/biocutils/normalize_subscript.py b/src/biocutils/normalize_subscript.py index 3514149..fa699d7 100644 --- a/src/biocutils/normalize_subscript.py +++ b/src/biocutils/normalize_subscript.py @@ -13,7 +13,7 @@ def _raise_int(idx: int, length): pass -def normalize_subscript(sub: Union[slice, range, Sequence, int, str, bool], length: int, names: Optional[Sequence[str]] = None) -> Tuple: +def normalize_subscript(sub: Union[slice, range, Sequence, int, str, bool], length: int, names: Optional[Sequence[str]] = None, non_negative_only: bool = True) -> Tuple: """ Normalize a subscript for ``__getitem__`` or friends into a sequence of integer indices, for consistent downstream use. @@ -45,6 +45,10 @@ def normalize_subscript(sub: Union[slice, range, Sequence, int, str, bool], leng List of names for each entry in the object. If not None, this should have length equal to ``length``. + non_negative_only: + Whether negative indices must be converted into non-negative + equivalents. Setting this to `False` may improve efficiency. + Returns: A tuple containing (i) a sequence of integer indices in ``[0, length)`` specifying the subscript elements, and (ii) a boolean indicating whether @@ -59,7 +63,7 @@ def normalize_subscript(sub: Union[slice, range, Sequence, int, str, bool], leng if isinstance(sub, int) or (has_numpy and isinstance(sub, numpy.generic)): if sub < -length or sub >= length: _raise_int(sub, length) - if sub < 0: + if sub < 0 and non_negative_only: sub += length return [int(sub)], True @@ -85,26 +89,29 @@ def normalize_subscript(sub: Union[slice, range, Sequence, int, str, bool], leng if last < -length: _raise_int(last, length) - if sub.start < 0: - if sub.stop < 0: - return range(length + sub.start, length + sub.stop, sub.step), False - else: - return [ (x < 0) * length + x for x in sub], False + if not non_negative_only: + return sub, False else: - if sub.stop < 0: - return [ (x < 0) * length + x for x in sub], False + if sub.start < 0: + if sub.stop < 0: + return range(length + sub.start, length + sub.stop, sub.step), False + else: + return [ (x < 0) * length + x for x in sub], False else: - return sub, False + if sub.stop < 0: + return [ (x < 0) * length + x for x in sub], False + else: + return sub, False - can_return_early = False + can_return_early = True for x in sub: - if isinstance(x, str) or isinstance(x, bool) or (has_numpy and isinstance(x, numpy.bool_)) or x < 0: + if isinstance(x, str) or isinstance(x, bool) or (has_numpy and isinstance(x, numpy.bool_)) or (x < 0 and non_negative_only): can_return_early = False; break if can_return_early: for x in sub: - if x >= length: + if x >= length or x < -length: _raise_int(x, length) return sub, False @@ -122,11 +129,11 @@ def normalize_subscript(sub: Union[slice, range, Sequence, int, str, bool], leng elif x < 0: if x < -length: _raise_int(x, length) - output.append(x + length) + output.append(int(x) + length) else: if x >= length: _raise_int(x, length) - output.append(x) + output.append(int(x)) if len(has_strings): if names is None: diff --git a/tests/test_normalize_subscript.py b/tests/test_normalize_subscript.py index a6fc02d..3516ae7 100644 --- a/tests/test_normalize_subscript.py +++ b/tests/test_normalize_subscript.py @@ -99,3 +99,14 @@ def test_normalize_subscript_numpy(): # Now the trickiest part - are booleans converted correctly? assert normalize_subscript(numpy.array([True, False, True, False, True]), 5) == ([0, 2, 4], False) + + +def test_normalize_subscript_allow_negative(): + assert normalize_subscript(-50, 100, non_negative_only=False) == ([-50], True) + assert normalize_subscript(range(50, -10, -1), 100, non_negative_only=False) == (range(50, -10, -1), False) + assert normalize_subscript(range(-10, -50, -1), 100, non_negative_only=False) == (range(-10, -50, -1), False) + assert normalize_subscript([0,-1,2,-3,4,-5,6,-7,8], 50, non_negative_only=False) == ([0,-1,2,-3,4,-5,6,-7,8], False) + + with pytest.raises(IndexError) as ex: + normalize_subscript([0,-1,2,-3,4,-51,6,-7,8], 50, non_negative_only=False) + assert str(ex.value).find("subscript (-51) out of range") >= 0