Skip to content

Commit

Permalink
Add comparer overload to sequence_equal()
Browse files Browse the repository at this point in the history
  • Loading branch information
cleoold committed May 9, 2021
1 parent 4d2357a commit 77a4a83
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 11 deletions.
21 changes: 21 additions & 0 deletions doc/api/types_linq.enumerable.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,27 @@ Example
----

instancemethod ``sequence_equal[TOther](second, __comparer)``
---------------------------------------------------------------

Parameters
- `second` (``Iterable[TOther]``)
- `__comparer` (``Callable[[TSource_co, TOther], bool]``)

Returns
- ``bool``

Determines whether two sequences are equal using a comparer that returns True if two values
are equal, on each element.

Example
>>> ints = [1, 3, 5, 7, 9]
>>> strs = ['1', '3', '5', '7', '9']
>>> Enumerable(ints).sequence_equal(strs, lambda x, y: str(x) == y)
True

----

instancemethod ``single()``
-----------------------------

Expand Down
2 changes: 1 addition & 1 deletion doc/to-start/installing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ execute the following commands:
$ make html
Note to generate api files, one must have Python version 3.9 or above. The api rst files
are commited in the repository.
are commited to the repository.
28 changes: 20 additions & 8 deletions tests/test_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,53 +934,65 @@ def test_selectmany2_overload2(self):


class TestSequenceEqualMethod:
def test_sequence_equal(self):
def test_overload1(self):
lst = [['a'], ['x'], ['y'], [16], [17]]
en1 = Enumerable(lst)
en2 = Enumerable((['a'], ['x'], ['y'], [16], [17]))
assert en1.sequence_equal(en2) is True
assert en2.sequence_equal(en1) is True
assert en2.sequence_equal(lst) is True

def test_1_elem(self):
def test_overload1_1_elem(self):
en1 = Enumerable(['a'])
en2 = Enumerable('a')
assert en1.sequence_equal(en2)
assert en2.sequence_equal(en1)

def test_both_empty(self):
def test_overload1_both_empty(self):
assert Enumerable.empty().sequence_equal([]) is True

def test_one_empty(self):
def test_overload1_one_empty(self):
en1 = Enumerable.empty().cast(str)
en2 = Enumerable(['a', 'x'])
assert en1.sequence_equal(en2) is False
assert en2.sequence_equal(en1) is False

def test_one_more(self):
def test_overload1_one_more(self):
en1 = Enumerable(['a', 'x', 'y'])
en2 = Enumerable(['a', 'x', 'y', 't'])
assert en1.sequence_equal(en2) is False
assert en2.sequence_equal(en1) is False

def test_first_off(self):
def test_overload1_first_off(self):
en1 = Enumerable(['a', 'x', 'y'])
en2 = Enumerable(['b', 'x', 'y'])
assert en1.sequence_equal(en2) is False
assert en2.sequence_equal(en1) is False

def test_last_off(self):
def test_overload1_last_off(self):
en1 = Enumerable(['a', 'x', 'y'])
en2 = Enumerable(['a', 'x', 'z'])
assert en1.sequence_equal(en2) is False
assert en2.sequence_equal(en1) is False

def test_middle_off(self):
def test_overload1_middle_off(self):
en1 = Enumerable(['a', 'x', 'y', 'z', 'u'])
en2 = Enumerable(['b', 'x', 'k', 'z', 'u'])
assert en1.sequence_equal(en2) is False
assert en2.sequence_equal(en1) is False

def test_overload2_yes(self):
ints = [1, 3, 5, 7, 9]
en = Enumerable(ints)
strs = ['1', '3', '5', '7', '9']
assert en.sequence_equal(strs, lambda x, y: str(x) == y) is True

def test_overload2_no(self):
ints = [1, 3, 5, 7, 9]
en = Enumerable(ints)
strs = ['1', '3', '5', '7', 'z']
assert en.sequence_equal(strs, lambda x, y: str(x) == y) is False


class TestSingleMethod:
def test_single_overload1_yes(self):
Expand Down
12 changes: 10 additions & 2 deletions types_linq/enumerable.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,15 @@ def inner():
yield result_selector(elem, sub)
return Enumerable(inner)

def sequence_equal(self, second: Iterable[TSource_co]) -> bool:
def sequence_equal(self,
second: Iterable[TSource_co],
*args: Callable[..., bool],
) -> bool:
if len(args) == 0:
comparer = lambda x, y: x == y
else: # len(args) == 1
comparer = args[0]

me, she = iter(self), iter(second)
while True:
try:
Expand All @@ -628,7 +636,7 @@ def sequence_equal(self, second: Iterable[TSource_co]) -> bool:
rhs = next(she)
except StopIteration:
return False
if not (lhs == rhs):
if not comparer(lhs, rhs):
return False

def _find_single(self, res):
Expand Down
17 changes: 17 additions & 0 deletions types_linq/enumerable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
Table 2: Chicken
'''

@overload
def sequence_equal(self, second: Iterable[TSource_co]) -> bool:
'''
Determines whether two sequences are equal using `==` on each element.
Expand All @@ -1187,6 +1188,22 @@ class Enumerable(Sequence[TSource_co], Generic[TSource_co]):
True
'''

@overload
def sequence_equal(self,
second: Iterable[TOther],
__comparer: Callable[[TSource_co, TOther], bool],
) -> bool:
'''
Determines whether two sequences are equal using a comparer that returns True if two values
are equal, on each element.
Example
>>> ints = [1, 3, 5, 7, 9]
>>> strs = ['1', '3', '5', '7', '9']
>>> Enumerable(ints).sequence_equal(strs, lambda x, y: str(x) == y)
True
'''

@overload
def single(self) -> TSource_co:
'''
Expand Down

0 comments on commit 77a4a83

Please sign in to comment.