diff --git a/src/collektions/_mapping.py b/src/collektions/_mapping.py index d92db35..63037a9 100644 --- a/src/collektions/_mapping.py +++ b/src/collektions/_mapping.py @@ -9,46 +9,28 @@ "map_values_to", ] -from collections.abc import Iterable, Mapping, MutableMapping, MutableSequence +from collections.abc import Mapping, MutableMapping from typing import Callable from ._types import K, R, V def filter_keys( - mapping: Mapping[K, V], predicate: Callable[[K, V], bool] -) -> Iterable[K]: - """Filter ``mapping``'s key set based on ``predicate``. - - The predicate function takes both the key and value from each key/value pair - so that filtering can be done on either. - - Returns: - A collection of the key from each key/value pair for which predicate returns ``True``. + mapping: Mapping[K, V], predicate: Callable[[K], bool] +) -> Mapping[K, V]: """ - result: MutableSequence[K] = [] - for key, value in mapping.items(): - if predicate(key, value): - result.append(key) - return result + Return a new mapping of all key/value pairs from ``mapping`` where `key` satisfies ``predicate``. + """ + return {k: v for k, v in mapping.items() if predicate(k)} def filter_values( - mapping: Mapping[K, V], predicate: Callable[[K, V], bool] -) -> Iterable[V]: - """Filter ``mapping``'s value set based on ``predicate``. - - The predicate function takes both the key and value from each key/value pair - so that filtering can be done on either. - - Returns: - A collection of the value from each key/value pair for which predicate returns ``True``. + mapping: Mapping[K, V], predicate: Callable[[V], bool] +) -> Mapping[K, V]: """ - result: MutableSequence[V] = [] - for key, value in mapping.items(): - if predicate(key, value): - result.append(value) - return result + Return a new mapping of all key/value pairs from ``mapping`` where `value` satisfies ``predicate``. + """ + return {k: v for k, v in mapping.items() if predicate(v)} def map_keys(mapping: Mapping[K, V], transform: Callable[[K, V], R]) -> Mapping[R, V]: diff --git a/tests/test__mapping.py b/tests/test__mapping.py index f9c771d..89fa9ef 100644 --- a/tests/test__mapping.py +++ b/tests/test__mapping.py @@ -8,15 +8,15 @@ def test_filter_keys(): mapping = {i: str(i) for i in range(10)} - expected = [0, 2, 4, 6, 8] - actual = filter_keys(mapping, lambda k, _: k % 2 == 0) + expected = {0: "0", 2: "2", 4: "4", 6: "6", 8: "8"} + actual = filter_keys(mapping, lambda k: k % 2 == 0) assert_that(actual, equal_to(expected)) def test_filter_values(): mapping = {i: str(i) for i in range(10)} - expected = ["0", "2", "4", "6", "8"] - actual = filter_values(mapping, lambda _, v: int(v) % 2 == 0) + expected = {0: "0", 2: "2", 4: "4", 6: "6", 8: "8"} + actual = filter_values(mapping, lambda v: int(v) % 2 == 0) assert_that(actual, equal_to(expected))