From 2dd48608987c33ba2b6d2d7dc6d1514867acd9f1 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 7 Dec 2022 11:32:02 -0800 Subject: [PATCH] Use doctest for testing python example docstrings (#1073) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Similar to https://github.com/rapidsai/cudf/pull/9815, this change uses doctest to test that the pylibraft example docstrings run without issue. This caught several errors in the example docstrings, that are also fixed in this PR: * a missing ‘device_ndarray’ import in kmeans fit when the centroids weren’t explicitly passed in * an error in the fused_l2_nn_argmin docstring where output wasn’t defined * An `AttributeError: module 'pylibraft.neighbors.ivf_pq' has no attribute 'np'` error in ivf_pq Closes https://github.com/rapidsai/raft/issues/981 Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1073 --- .../pylibraft/pylibraft/cluster/__init__.py | 4 +- python/pylibraft/pylibraft/cluster/kmeans.pyx | 83 ++++++------ .../pylibraft/pylibraft/distance/__init__.py | 2 + .../pylibraft/distance/fused_l2_nn.pyx | 20 +-- .../pylibraft/distance/pairwise_distance.pyx | 11 +- .../pylibraft/pylibraft/neighbors/__init__.py | 2 + .../pylibraft/neighbors/ivf_pq/__init__.py | 2 + .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 88 ++++++------- .../pylibraft/pylibraft/neighbors/refine.pyx | 70 +++++----- python/pylibraft/pylibraft/random/__init__.py | 2 + .../random/rmat_rectangular_generator.pyx | 12 +- .../pylibraft/pylibraft/test/test_doctests.py | 122 ++++++++++++++++++ 12 files changed, 271 insertions(+), 147 deletions(-) create mode 100644 python/pylibraft/pylibraft/test/test_doctests.py diff --git a/python/pylibraft/pylibraft/cluster/__init__.py b/python/pylibraft/pylibraft/cluster/__init__.py index 89a403fce2..4facc3dae2 100644 --- a/python/pylibraft/pylibraft/cluster/__init__.py +++ b/python/pylibraft/pylibraft/cluster/__init__.py @@ -13,4 +13,6 @@ # limitations under the License. # -from .kmeans import compute_new_centroids +from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit + +__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"] diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index ca25b45843..9097eccfa8 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -27,7 +27,7 @@ from libcpp cimport nullptr from collections import namedtuple from enum import IntEnum -from pylibraft.common import Handle, cai_wrapper +from pylibraft.common import Handle, cai_wrapper, device_ndarray from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -81,33 +81,33 @@ def compute_new_centroids(X, -------- >>> import cupy as cp - >>> + >>> from pylibraft.common import Handle >>> from pylibraft.cluster.kmeans import compute_new_centroids - >>> + >>> # A single RAFT handle can optionally be reused across >>> # pylibraft functions. >>> handle = Handle() - >>> + >>> n_samples = 5000 >>> n_features = 50 >>> n_clusters = 3 - >>> + >>> X = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) - >>> + ... dtype=cp.float32) + >>> centroids = cp.random.random_sample((n_clusters, n_features), - >>> dtype=cp.float32) - >>> + ... dtype=cp.float32) + ... >>> labels = cp.random.randint(0, high=n_clusters, size=n_samples, - >>> dtype=cp.int32) - >>> + ... dtype=cp.int32) + >>> new_centroids = cp.empty((n_clusters, n_features), dtype=cp.float32) - >>> + >>> compute_new_centroids( - >>> X, centroids, labels, new_centroids, handle=handle - >>> ) - >>> + ... X, centroids, labels, new_centroids, handle=handle + ... ) + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() @@ -211,22 +211,21 @@ def cluster_cost(X, centroids, handle=None): Examples -------- - .. code-block:: python - import cupy as cp - - from pylibraft.cluster.kmeans import cluster_cost - - n_samples = 5000 - n_features = 50 - n_clusters = 3 - - X = cp.random.random_sample((n_samples, n_features), - dtype=cp.float32) + >>> import cupy as cp + >>> + >>> from pylibraft.cluster.kmeans import cluster_cost + >>> + >>> n_samples = 5000 + >>> n_features = 50 + >>> n_clusters = 3 + >>> + >>> X = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) - centroids = cp.random.random_sample((n_clusters, n_features), - dtype=cp.float32) + >>> centroids = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) - inertia = cluster_cost(X, centroids) + >>> inertia = cluster_cost(X, centroids) """ x_cai = X.__cuda_array_interface__ centroids_cai = centroids.__cuda_array_interface__ @@ -434,21 +433,19 @@ def fit( Examples -------- - .. code-block:: python - - import cupy as cp - - from pylibraft.cluster.kmeans import fit, KMeansParams - - n_samples = 5000 - n_features = 50 - n_clusters = 3 - - X = cp.random.random_sample((n_samples, n_features), - dtype=cp.float32) + >>> import cupy as cp + >>> + >>> from pylibraft.cluster.kmeans import fit, KMeansParams + >>> + >>> n_samples = 5000 + >>> n_features = 50 + >>> n_clusters = 3 + >>> + >>> X = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) - params = KMeansParams(n_clusters=n_clusters) - centroids, inertia, n_iter = fit(params, X) + >>> params = KMeansParams(n_clusters=n_clusters) + >>> centroids, inertia, n_iter = fit(params, X) """ cdef handle_t *h = handle.getHandle() diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py index b251e71ba3..f059b5f3dd 100644 --- a/python/pylibraft/pylibraft/distance/__init__.py +++ b/python/pylibraft/pylibraft/distance/__init__.py @@ -15,3 +15,5 @@ from .fused_l2_nn import fused_l2_nn_argmin from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance + +__all__ = ["fused_l2_nn_argmin", "pairwise_distance"] diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index 81a81b2632..a21fe46fa3 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -80,15 +80,15 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None): >>> n_clusters = 5 >>> n_features = 50 >>> in1 = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> in2 = cp.random.random_sample((n_clusters, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> # A single RAFT handle can optionally be reused across >>> # pylibraft functions. >>> handle = Handle() - >>> ... - >>> output = fused_l2_nn_argmin(in1, in2, output, handle=handle) - >>> ... + + >>> output = fused_l2_nn_argmin(in1, in2, handle=handle) + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() @@ -103,20 +103,20 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None): >>> n_clusters = 5 >>> n_features = 50 >>> in1 = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> in2 = cp.random.random_sample((n_clusters, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> output = cp.empty((n_samples, 1), dtype=cp.int32) >>> # A single RAFT handle can optionally be reused across >>> # pylibraft functions. >>> handle = Handle() - >>> ... + >>> fused_l2_nn_argmin(in1, in2, out=output, handle=handle) - >>> ... + array(...) + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() - """ x_cai = cai_wrapper(X) diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 450444f953..6f7a135951 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -124,9 +124,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None): >>> n_samples = 5000 >>> n_features = 50 >>> in1 = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> in2 = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) A single RAFT handle can optionally be reused across pylibraft functions. @@ -147,9 +147,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None): >>> n_samples = 5000 >>> n_features = 50 >>> in1 = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> in2 = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> output = cp.empty((n_samples, n_samples), dtype=cp.float32) A single RAFT handle can optionally be reused across @@ -158,7 +158,8 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None): >>> >>> handle = Handle() >>> pairwise_distance(in1, in2, out=output, - >>> metric="euclidean", handle=handle) + ... metric="euclidean", handle=handle) + array(...) pylibraft functions are often asynchronous so the handle needs to be explicitly synchronized diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index 2f5104bd6b..dd8cdd8445 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. # from .refine import refine + +__all__ = ["refine"] diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py b/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py index 8a231b2c8c..559eb21fdf 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py @@ -14,3 +14,5 @@ # from .ivf_pq import Index, IndexParams, SearchParams, build, extend, search + +__all__ = ["Index", "IndexParams", "SearchParams", "build", "extend", "search"] diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index d98d0432da..fdc8d1755c 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -321,37 +321,36 @@ def build(IndexParams index_params, dataset, handle=None): -------- >>> import cupy as cp - >>> + >>> from pylibraft.common import Handle >>> from pylibraft.neighbors import ivf_pq - >>> + >>> n_samples = 50000 >>> n_features = 50 >>> n_queries = 1000 - >>> + >>> dataset = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> handle = Handle() >>> index_params = ivf_pq.IndexParams( - >>> n_lists=1024, - >>> metric="l2_expanded", - >>> pq_dim=10) + ... n_lists=1024, + ... metric="l2_expanded", + ... pq_dim=10) >>> index = ivf_pq.build(index_params, dataset, handle=handle) - >>> + >>> # Search using the built index >>> queries = cp.random.random_sample((n_queries, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> k = 10 >>> distances, neighbors = ivf_pq.search(ivf_pq.SearchParams(), index, - >>> queries, k, handle=handle) - >>> + ... queries, k, handle=handle) + >>> distances = cp.asarray(distances) >>> neighbors = cp.asarray(neighbors) - >>> + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() - """ dataset_cai = cai_wrapper(dataset) dataset_dt = dataset_cai.dtype @@ -425,37 +424,37 @@ def extend(Index index, new_vectors, new_indices, handle=None): -------- >>> import cupy as cp - >>> + >>> from pylibraft.common import Handle >>> from pylibraft.neighbors import ivf_pq - >>> + >>> n_samples = 50000 >>> n_features = 50 >>> n_queries = 1000 - >>> + >>> dataset = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> handle = Handle() >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset, handle=handle) - >>> + >>> n_rows = 100 >>> more_data = cp.random.random_sample((n_rows, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> indices = index.size + cp.arange(n_rows, dtype=cp.uint64) >>> index = ivf_pq.extend(index, more_data, indices) - >>> + >>> # Search using the built index >>> queries = cp.random.random_sample((n_queries, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> k = 10 >>> distances, neighbors = ivf_pq.search(ivf_pq.SearchParams(), - >>> index, queries, - >>> k, handle=handle) - >>> + ... index, queries, + ... k, handle=handle) + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() - >>> + >>> distances = cp.asarray(distances) >>> neighbors = cp.asarray(neighbors) """ @@ -602,46 +601,47 @@ def search(SearchParams search_params, Examples -------- >>> import cupy as cp - >>> + >>> from pylibraft.common import Handle >>> from pylibraft.neighbors import ivf_pq - >>> + >>> n_samples = 50000 >>> n_features = 50 >>> n_queries = 1000 >>> dataset = cp.random.random_sample((n_samples, n_features), - >>> dtype=cp.float32) - >>> + ... dtype=cp.float32) + >>> # Build index >>> handle = Handle() >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset, handle=handle) - >>> + >>> # Search using the built index >>> queries = cp.random.random_sample((n_queries, n_features), - >>> dtype=cp.float32) + ... dtype=cp.float32) >>> k = 10 >>> search_params = ivf_pq.SearchParams( - >>> n_probes=20, - >>> lut_dtype=ivf_pq.np.float16, - >>> internal_distance_dtype=ivf_pq.np.float32 - >>> ) - >>> + ... n_probes=20, + ... lut_dtype=cp.float16, + ... internal_distance_dtype=cp.float32 + ... ) + >>> # Using a pooling allocator reduces overhead of temporary array >>> # creation during search. This is useful if multiple searches >>> # are performad with same query size. + >>> import rmm >>> mr = rmm.mr.PoolMemoryResource( - >>> rmm.mr.CudaMemoryResource(), - >>> initial_pool_size=2**29, - >>> maximum_pool_size=2**31 - >>> ) + ... rmm.mr.CudaMemoryResource(), + ... initial_pool_size=2**29, + ... maximum_pool_size=2**31 + ... ) >>> distances, neighbors = ivf_pq.search(search_params, index, queries, - >>> k, memory_resource=mr, - >>> handle=handle) - >>> + ... k, memory_resource=mr, + ... handle=handle) + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() - >>> + >>> neighbors = cp.asarray(neighbors) >>> distances = cp.asarray(distances) """ diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index 206fe15dfb..ca328c1cd5 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -253,44 +253,38 @@ def refine(dataset, queries, candidates, k=None, indices=None, distances=None, Examples -------- - .. code-block:: python - - import cupy as cp - - from pylibraft.common import Handle - from pylibraft.neighbors import ivf_pq, refine - - n_samples = 50000 - n_features = 50 - n_queries = 1000 - - dataset = cp.random.random_sample((n_samples, n_features), - dtype=cp.float32) - handle = Handle() - index_params = ivf_pq.IndexParams( - n_lists=1024, - metric="l2_expanded", - pq_dim=10) - index = ivf_pq.build(index_params, dataset, handle=handle) - - # Search using the built index - queries = cp.random.random_sample((n_queries, n_features), - dtype=cp.float32) - k = 40 - _, candidates = ivf_pq.search(ivf_pq.SearchParams(), index, - queries, k, handle=handle) - - k = 10 - distances, neighbors = refine(dataset, queries, candidates, k, - handle=handle) - distances = cp.asarray(distances) - neighbors = cp.asarray(neighbors) - - - # pylibraft functions are often asynchronous so the - # handle needs to be explicitly synchronized - handle.sync() - + >>> import cupy as cp + + >>> from pylibraft.common import Handle + >>> from pylibraft.neighbors import ivf_pq, refine + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> handle = Handle() + >>> index_params = ivf_pq.IndexParams(n_lists=1024, metric="l2_expanded", + ... pq_dim=10) + >>> index = ivf_pq.build(index_params, dataset, handle=handle) + + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 40 + >>> _, candidates = ivf_pq.search(ivf_pq.SearchParams(), index, + ... queries, k, handle=handle) + + >>> k = 10 + >>> distances, neighbors = refine(dataset, queries, candidates, k, + ... handle=handle) + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() """ if handle is None: diff --git a/python/pylibraft/pylibraft/random/__init__.py b/python/pylibraft/pylibraft/random/__init__.py index c34e4e6bdb..1c47a6eaac 100644 --- a/python/pylibraft/pylibraft/random/__init__.py +++ b/python/pylibraft/pylibraft/random/__init__.py @@ -14,3 +14,5 @@ # from .rmat_rectangular_generator import rmat + +__all__ = ["rmat"] diff --git a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx index 17c574bea5..56d6ced468 100644 --- a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx +++ b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx @@ -98,24 +98,24 @@ def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None): -------- >>> import cupy as cp - >>> + >>> from pylibraft.common import Handle >>> from pylibraft.random import rmat - >>> + >>> n_edges = 5000 >>> r_scale = 16 >>> c_scale = 14 >>> theta_len = max(r_scale, c_scale) * 4 - >>> + >>> out = cp.empty((n_edges, 2), dtype=cp.int32) >>> theta = cp.random.random_sample(theta_len, dtype=cp.float32) - >>> + >>> # A single RAFT handle can optionally be reused across >>> # pylibraft functions. >>> handle = Handle() - >>> ... + >>> rmat(out, theta, r_scale, c_scale, handle=handle) - >>> ... + >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized >>> handle.sync() diff --git a/python/pylibraft/pylibraft/test/test_doctests.py b/python/pylibraft/pylibraft/test/test_doctests.py new file mode 100644 index 0000000000..29260c7128 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_doctests.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import doctest +import inspect +import io + +import pytest + +import pylibraft.cluster +import pylibraft.distance +import pylibraft.neighbors +import pylibraft.random + +# Code adapted from https://github.com/rapidsai/cudf/blob/branch-23.02/python/cudf/cudf/tests/test_doctests.py # noqa + + +def _name_in_all(parent, name): + return name in getattr(parent, "__all__", []) + + +def _is_public_name(parent, name): + return not name.startswith("_") + + +def _find_doctests_in_obj(obj, finder=None, criteria=None): + """Find all doctests in an object. + + Parameters + ---------- + obj : module or class + The object to search for docstring examples. + finder : doctest.DocTestFinder, optional + The DocTestFinder object to use. If not provided, a DocTestFinder is + constructed. + criteria : callable, optional + Callable indicating whether to recurse over members of the provided + object. If not provided, names not defined in the object's ``__all__`` + property are ignored. + + Yields + ------ + doctest.DocTest + The next doctest found in the object. + """ + if finder is None: + finder = doctest.DocTestFinder() + if criteria is None: + criteria = _name_in_all + for docstring in finder.find(obj): + if docstring.examples: + yield docstring + for name, member in inspect.getmembers(obj): + # Only recurse over members matching the criteria + if not criteria(obj, name): + continue + # Recurse over the public API of modules (objects defined in the + # module's __all__) + if inspect.ismodule(member): + yield from _find_doctests_in_obj( + member, finder, criteria=_name_in_all + ) + # Recurse over the public API of classes (attributes not prefixed with + # an underscore) + if inspect.isclass(member): + yield from _find_doctests_in_obj( + member, finder, criteria=_is_public_name + ) + + # doctest finder seems to dislike cython functions, since + # `inspect.isfunction` doesn't return true for them. hack around this + if callable(member) and not inspect.isfunction(member): + for docstring in finder.find(member): + if docstring.examples: + yield docstring + + +# since the root pylibraft module doesn't import submodules (or define an +# __all__) we are explicitly adding all the submodules we want to run +# doctests for here +DOC_STRINGS = list(_find_doctests_in_obj(pylibraft.cluster)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.distance)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.random)) + + +@pytest.mark.parametrize( + "docstring", + DOC_STRINGS, + ids=lambda docstring: docstring.name, +) +def test_docstring(docstring): + # We ignore differences in whitespace in the doctest output, and enable + # the use of an ellipsis "..." to match any string in the doctest + # output. An ellipsis is useful for, e.g., memory addresses or + # imprecise floating point values. + optionflags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE + runner = doctest.DocTestRunner(optionflags=optionflags) + + # Capture stdout and include failing outputs in the traceback. + doctest_stdout = io.StringIO() + with contextlib.redirect_stdout(doctest_stdout): + runner.run(docstring) + results = runner.summarize() + assert not results.failed, ( + f"{results.failed} of {results.attempted} doctests failed for " + f"{docstring.name}:\n{doctest_stdout.getvalue()}" + )