From 693c169ca6ee26b95984c1d56d31a352b203ffc6 Mon Sep 17 00:00:00 2001 From: LTLA Date: Tue, 17 Dec 2024 21:15:09 -0800 Subject: [PATCH] Added fail_missing= to match() to fail on missing entries. This defaults to True if the return type is unsigned, as otherwise there is no way to reliably represent missing values in the returned array. --- src/biocutils/factorize.py | 7 ++++++- src/biocutils/match.py | 29 +++++++++++++++++++++++------ tests/test_match.py | 15 +++++++++++++++ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/src/biocutils/factorize.py b/src/biocutils/factorize.py index 0d538d6..904d24d 100644 --- a/src/biocutils/factorize.py +++ b/src/biocutils/factorize.py @@ -11,6 +11,7 @@ def factorize( levels: Optional[Sequence] = None, sort_levels: bool = False, dtype: Optional[numpy.dtype] = None, + fail_missing: Optional[bool] = None, ) -> Tuple[list, numpy.ndarray]: """Convert a sequence of hashable values into a factor. @@ -32,6 +33,10 @@ def factorize( NumPy type of the array of indices, see :py:func:`~biocutils.match.match` for details. + fail_missing: + Whether to raise an error upon encountering missing levels in + ``x``, see :py:func:`~biocutils.match.match` for details. + Returns: Tuple where the first element is a list of unique levels and the second element in a NumPy array containing integer codes, i.e., indices into @@ -51,5 +56,5 @@ def factorize( if sort_levels: levels.sort() - codes = match(x, levels, dtype=dtype) + codes = match(x, levels, dtype=dtype, fail_missing=fail_missing) return levels, codes diff --git a/src/biocutils/match.py b/src/biocutils/match.py index e9d12b2..2881bec 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -9,11 +9,13 @@ def match( targets: Union[dict, Sequence], duplicate_method: DUPLICATE_METHOD = "first", dtype: Optional[numpy.ndarray] = None, + fail_missing: Optional[bool] = None, ) -> numpy.ndarray: """Find a matching value of each element of ``x`` in ``target``. 
Args: - x: Squence of values to match. + x: + Sequence of values to match. targets: Sequence of targets to be matched against. Alternatively, a @@ -27,7 +29,12 @@ def match( dtype: NumPy type of the output array. This should be an integer type; if missing values are expected, the type should be a signed integer. - If None, a suitable type is automatically determined. + If None, a suitable signed type is automatically determined. + + fail_missing: + Whether to raise an error if any element of ``x`` cannot be found in ``targets``. + If ``None``, this defaults to ``True`` if ``dtype`` is an unsigned + type, otherwise it defaults to ``False``. Returns: Array of length equal to ``x``, containing the integer position of each @@ -41,10 +48,20 @@ def match( dtype = numpy.min_scalar_type(-len(targets)) # get a signed type indices = numpy.zeros(len(x), dtype=dtype) - for i, y in enumerate(x): - if y not in targets: - indices[i] = -1 - else: + if fail_missing is None: + fail_missing = numpy.issubdtype(dtype, numpy.unsignedinteger) + + # Separate loops to reduce branching in the tight inner loop. 
+ if not fail_missing: + for i, y in enumerate(x): + if y in targets: + indices[i] = targets[y] + else: + indices[i] = -1 + else: + for i, y in enumerate(x): + if not y in targets: + raise ValueError("cannot find '" + str(y) + "' in 'targets'") indices[i] = targets[y] return indices diff --git a/tests/test_match.py b/tests/test_match.py index 3efb970..c8fd6ba 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -1,5 +1,6 @@ from biocutils import match, map_to_index import numpy +import pytest def test_match_simple(): @@ -39,3 +40,17 @@ def test_match_dtype(): mm = match(["A", "B", "D", "A", "C", "B"], ["D", "C", "B", "A"], dtype=numpy.dtype("uint32")) assert list(mm) == [3, 2, 0, 3, 1, 2] assert mm.dtype == numpy.dtype("uint32") + + +def test_match_fail_missing(): + x = match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"]) + assert list(x) == [3, -1, 2, 0, -1] + + with pytest.raises(ValueError, match="cannot find"): + match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"], fail_missing=True) + + with pytest.raises(ValueError, match="cannot find"): + match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"], dtype=numpy.uint32) + + x = match(["A", "C", "B", "D", "C"], ["D", "C", "B", "A"], fail_missing=True) + assert list(x) == [3, 1, 2, 0, 1]