diff --git a/src/biocutils/match.py b/src/biocutils/match.py index e9d12b2..2881bec 100644 --- a/src/biocutils/match.py +++ b/src/biocutils/match.py @@ -9,11 +9,13 @@ def match( targets: Union[dict, Sequence], duplicate_method: DUPLICATE_METHOD = "first", dtype: Optional[numpy.ndarray] = None, + fail_missing: Optional[bool] = None, ) -> numpy.ndarray: """Find a matching value of each element of ``x`` in ``target``. Args: - x: Squence of values to match. + x: + Sequence of values to match. targets: Sequence of targets to be matched against. Alternatively, a @@ -27,7 +29,12 @@ def match( dtype: NumPy type of the output array. This should be an integer type; if missing values are expected, the type should be a signed integer. - If None, a suitable type is automatically determined. + If None, a suitable signed type is automatically determined. + + fail_missing: + Whether to raise an error if ``x`` cannot be found in ``targets``. + If ``None``, this defaults to ``True`` if ``dtype`` is an unsigned + type, otherwise it defaults to ``False``. Returns: Array of length equal to ``x``, containing the integer position of each @@ -41,10 +48,20 @@ def match( dtype = numpy.min_scalar_type(-len(targets)) # get a signed type indices = numpy.zeros(len(x), dtype=dtype) - for i, y in enumerate(x): - if y not in targets: - indices[i] = -1 - else: + if fail_missing is None: + fail_missing = numpy.issubdtype(dtype, numpy.unsignedinteger) + + # Separate loops to reduce branching in the tight inner loop. + if not fail_missing: + for i, y in enumerate(x): + if y in targets: + indices[i] = targets[y] + else: + indices[i] = -1 + else: + for i, y in enumerate(x): + if not y in targets: + raise ValueError("cannot find '" + str(y) + "' in 'targets'") indices[i] = targets[y] return indices diff --git a/tests/test_match.py b/tests/test_match.py index 3efb970..c8fd6ba 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -1,5 +1,6 @@ from biocutils import match, map_to_index import numpy +import pytest def test_match_simple(): @@ -39,3 +40,17 @@ def test_match_dtype(): mm = match(["A", "B", "D", "A", "C", "B"], ["D", "C", "B", "A"], dtype=numpy.dtype("uint32")) assert list(mm) == [3, 2, 0, 3, 1, 2] assert mm.dtype == numpy.dtype("uint32") + + +def test_match_fail_missing(): + x = match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"]) + assert list(x) == [3, -1, 2, 0, -1] + + with pytest.raises(ValueError, match="cannot find"): + match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"], fail_missing=True) + + with pytest.raises(ValueError, match="cannot find"): + match(["A", "E", "B", "D", "E"], ["D", "C", "B", "A"], dtype=numpy.uint32) + + x = match(["A", "C", "B", "D", "C"], ["D", "C", "B", "A"], fail_missing=True) + assert list(x) == [3, 1, 2, 0, 1]