diff --git a/pipeline/__init__.py b/pipeline/__init__.py
index e69de29b..dbc5a8e3 100644
--- a/pipeline/__init__.py
+++ b/pipeline/__init__.py
@@ -0,0 +1,11 @@
+try:
+    import pytest
+except ImportError:
+    # Keep `import pipeline` working where pytest is not installed; only the
+    # assertion-rewrite registration below needs it.
+    pytest = None
+
+if pytest is not None:
+    # Register any modules that require pytest assertion rewriting
+    # https://docs.pytest.org/en/stable/how-to/writing_plugins.html#assertion-rewriting
+    pytest.register_assert_rewrite("pipeline.testing.assertions")
diff --git a/pipeline/testing/__init__.py b/pipeline/testing/__init__.py
new file mode 100644
index 00000000..f38a67c7
--- /dev/null
+++ b/pipeline/testing/__init__.py
@@ -0,0 +1 @@
+from pipeline.testing.assertions import *
diff --git a/pipeline/testing/assertions.py b/pipeline/testing/assertions.py
new file mode 100644
index 00000000..3bb3ab8a
--- /dev/null
+++ b/pipeline/testing/assertions.py
@@ -0,0 +1,97 @@
+from typing import Any, List
+
+from pipeline.testing.helpers import get_first_mismatched_pair
+
+
+def get_assertion_message(actual: Any, expected: Any) -> str:
+    """
+    Compare two values and return their assertion error message, or an empty
+    string if no message.
+
+    When pytest assertion rewriting is enabled, the assertion error may provide
+    a helpful description of the difference between the two values:
+    https://docs.pytest.org/en/stable/how-to/writing_plugins.html#assertion-rewriting
+
+    This method can help capture that description without raising an error, so
+    that other testing methods can add more information to the message.
+
+    Note: `assert` statements are removed when Python runs with `-O`, so in
+    that mode this always returns an empty string.
+    """
+    try:
+        assert actual == expected
+    except AssertionError as error:
+        # Every exception instance is truthy, so no emptiness check is needed;
+        # an assertion error without a message simply stringifies to "".
+        return str(error)
+    return ""
+
+
+def get_assertion_message_first_mismatch(actual: List, expected: List) -> str:
+    """
+    Return the assertion message for the first pair of items that differ
+    between `actual` and `expected`, prefixed with the index of that pair.
+    """
+    index, pair = get_first_mismatched_pair(actual, expected)
+    actual_item, expected_item = pair
+    diff = get_assertion_message(actual_item, expected_item)
+    return f"First pair of mismatched items at index={index}:\n{diff}\n"
+
+
+def _have_same_items(actual: List, expected: List) -> bool:
+    """
+    Return True if two lists of equal length hold the same items, in any order.
+
+    Sorting is tried first. Items that cannot be ordered (for example
+    dataclasses declared without `order=True`) make `sorted` raise TypeError;
+    in that case fall back to pairing items up using equality only.
+    """
+    try:
+        return sorted(actual) == sorted(expected)
+    except TypeError:
+        unmatched = list(expected)
+        for item in actual:
+            try:
+                unmatched.remove(item)
+            except ValueError:
+                return False
+        return True
+
+
+def equal_lists(actual: List, expected: List) -> bool:
+    """
+    Compares two lists, adding an additional explanation to the assertion error.
+
+    When lists are the same size, but have mismatching items, this assertion
+    displays assertion error for the first pair of items that do not match.
+
+    This method is recommended for use when comparing lists of dataclasses, to
+    get a simpler diff of the items that do not match, rather than displaying
+    the entirety of each list.
+
+    Returns True when the lists are equal. Otherwise an AssertionError is
+    raised; False is never returned.
+    """
+    if actual == expected:
+        # Case 1: Lists match exactly
+        return True
+
+    a_len = len(actual)
+    e_len = len(expected)
+    if a_len != e_len:
+        # Case 2: Lists have different sizes
+        explanation = f"`actual` differs in size: {a_len}, should be {e_len}."
+        diff = get_assertion_message(actual, expected) + "\n"
+        footnote = "Above: diff between lists."
+    elif _have_same_items(actual, expected):
+        # Case 3: Lists have the same size and same items, but different orders
+        explanation = "`actual` has the same items, but in a different order."
+        diff = get_assertion_message_first_mismatch(actual, expected)
+        footnote = "Above: diff between first pair of mismatched items."
+    else:
+        # Case 4: Lists have the same size, but different items
+        explanation = "`actual` has the correct length, but different items."
+        diff = get_assertion_message_first_mismatch(actual, expected)
+        footnote = "Above: diff between first pair of mismatched items."
+
+    raise AssertionError("\n".join([diff, explanation, footnote]))
diff --git a/pipeline/testing/assertions_test.py b/pipeline/testing/assertions_test.py
new file mode 100644
index 00000000..1568c0b0
--- /dev/null
+++ b/pipeline/testing/assertions_test.py
@@ -0,0 +1,117 @@
+from dataclasses import dataclass
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from pipeline.testing.assertions import equal_lists
+
+MOCK_PYTEST_ASSERTION_MESSAGE = "(Pytest diff goes here.)"
+
+patch_get_assertion_message = patch(
+    "pipeline.testing.assertions.get_assertion_message",
+    return_value=MOCK_PYTEST_ASSERTION_MESSAGE,
+)
+
+
+# TODO(vinesh): Enable ordering for all our dataclasses that are used in lists
+@dataclass(order=True)
+class ExampleClass:
+    a: str
+    b: int
+
+
+# Ordering is intentionally left off, to cover items that cannot be sorted.
+@dataclass
+class UnorderedExampleClass:
+    a: str
+    b: int
+
+
+@patch_get_assertion_message
+def test_assert_dataclass_lists_different_items(get_assertion_message):
+    example_actual = [ExampleClass(a="A", b=2), ExampleClass(a="B", b=2)]
+    example_expected = [ExampleClass(a="A", b=2), ExampleClass(a="B", b=3)]
+
+    with pytest.raises(AssertionError) as excinfo:
+        equal_lists(example_actual, example_expected)
+    actual = str(excinfo.value)
+
+    expected = (
+        "First pair of mismatched items at index=1:\n"
+        "(Pytest diff goes here.)\n\n"
+        "`actual` has the correct length, but different items.\n"
+        "Above: diff between first pair of mismatched items."
+    )
+    assert actual == expected
+    get_assertion_message.assert_called_once()
+
+
+@patch_get_assertion_message
+def test_assert_dataclass_lists_different_order(get_assertion_message):
+    example_actual = [ExampleClass(a="B", b=2), ExampleClass(a="A", b=2)]
+    example_expected = [ExampleClass(a="A", b=2), ExampleClass(a="B", b=2)]
+
+    with pytest.raises(AssertionError) as excinfo:
+        equal_lists(example_actual, example_expected)
+    actual = str(excinfo.value)
+
+    expected = (
+        "First pair of mismatched items at index=0:\n"
+        "(Pytest diff goes here.)\n\n"
+        "`actual` has the same items, but in a different order.\n"
+        "Above: diff between first pair of mismatched items."
+    )
+    assert actual == expected
+    get_assertion_message.assert_called_once()
+
+
+@patch_get_assertion_message
+def test_assert_unorderable_lists_different_order(get_assertion_message):
+    example_actual = [
+        UnorderedExampleClass(a="B", b=2),
+        UnorderedExampleClass(a="A", b=2),
+    ]
+    example_expected = [
+        UnorderedExampleClass(a="A", b=2),
+        UnorderedExampleClass(a="B", b=2),
+    ]
+
+    with pytest.raises(AssertionError) as excinfo:
+        equal_lists(example_actual, example_expected)
+    actual = str(excinfo.value)
+
+    expected = (
+        "First pair of mismatched items at index=0:\n"
+        "(Pytest diff goes here.)\n\n"
+        "`actual` has the same items, but in a different order.\n"
+        "Above: diff between first pair of mismatched items."
+    )
+    assert actual == expected
+    get_assertion_message.assert_called_once()
+
+
+@patch_get_assertion_message
+def test_assert_dataclass_lists_different_sizes(get_assertion_message):
+    example_actual = [ExampleClass(a="B", b=2)]
+    example_expected = [ExampleClass(a="A", b=2), ExampleClass(a="B", b=2)]
+
+    with pytest.raises(AssertionError) as excinfo:
+        equal_lists(example_actual, example_expected)
+    actual = str(excinfo.value)
+
+    expected = (
+        "(Pytest diff goes here.)\n\n"
+        "`actual` differs in size: 1, should be 2.\n"
+        "Above: diff between lists."
+    )
+    assert actual == expected
+    get_assertion_message.assert_called_once()
+
+
+@patch_get_assertion_message
+def test_assert_dataclass_lists_equal(get_assertion_message):
+    example_actual = [ExampleClass(a="A", b=2), ExampleClass(a="B", b=2)]
+    example_expected = [ExampleClass(a="A", b=2), ExampleClass(a="B", b=2)]
+
+    assert equal_lists(example_actual, example_expected)
+    get_assertion_message.assert_not_called()
diff --git a/pipeline/testing/helpers.py b/pipeline/testing/helpers.py
new file mode 100644
index 00000000..b20c3e4d
--- /dev/null
+++ b/pipeline/testing/helpers.py
@@ -0,0 +1,41 @@
+from itertools import zip_longest
+from typing import List, Tuple
+
+IndexAndPair = Tuple[int, Tuple]
+
+# Padding marker for the shorter list. A private object is used because `None`
+# may itself be a legitimate list item.
+_MISSING = object()
+
+
+class EqualListsException(Exception):
+    pass
+
+
+class DifferentListSizesException(Exception):
+    pass
+
+
+def get_first_mismatched_pair(list_a: List, list_b: List) -> IndexAndPair:
+    """
+    Return the index and items of the first position where two lists differ.
+
+    If one list is shorter than the other, its missing item is reported as
+    `None` in the returned pair.
+
+    Raises EqualListsException if the lists are equal.
+    """
+    if list_a == list_b:
+        raise EqualListsException("Lists are equal, there is no mismatch.")
+
+    zipped = zip_longest(list_a, list_b, fillvalue=_MISSING)
+    for i, (item_a, item_b) in enumerate(zipped):
+        if item_a is _MISSING:
+            return i, (None, item_b)
+        if item_b is _MISSING:
+            return i, (item_a, None)
+        if item_a != item_b:
+            return i, (item_a, item_b)
+    # Only reachable when the lists compare unequal even though every pair of
+    # items compares equal.
+    return -1, (None, None)
diff --git a/pipeline/testing/helpers_test.py b/pipeline/testing/helpers_test.py
new file mode 100644
index 00000000..6e5159b4
--- /dev/null
+++ b/pipeline/testing/helpers_test.py
@@ -0,0 +1,44 @@
+import pytest
+
+from pipeline.testing.helpers import (
+    EqualListsException,
+    get_first_mismatched_pair,
+)
+
+
+def test_get_first_mismatched_pair_mismatch_in_middle():
+    list_a = [1, 2, 3]
+    list_b = [1, 4, 3]
+
+    actual = get_first_mismatched_pair(list_a, list_b)
+
+    expected = (1, (2, 4))
+    assert actual == expected
+
+
+def test_get_first_mismatched_pair_equal_lists():
+    list_a = [1, 1]
+    list_b = [1, 1]
+
+    with pytest.raises(EqualListsException):
+        get_first_mismatched_pair(list_a, list_b)
+
+
+def test_get_first_mismatched_pair_lists_of_different_lengths():
+    list_a = [9]
+    list_b = []
+
+    actual = get_first_mismatched_pair(list_a, list_b)
+
+    expected = (0, (9, None))
+    assert actual == expected
+
+
+def test_get_first_mismatched_pair_trailing_none_item():
+    list_a = [1, None]
+    list_b = [1]
+
+    actual = get_first_mismatched_pair(list_a, list_b)
+
+    expected = (1, (None, None))
+    assert actual == expected