Skip to content

Commit

Permalink
Use doctest for testing python example docstrings
Browse files Browse the repository at this point in the history
Similar to rapidsai/cudf#9815, this change uses doctest
to test that the pylibraft example docstrings run without issue.

This caught several errors in the example docstrings, that are also fixed in this PR:
 *  a missing ‘device_ndarray’ import in kmeans fit when the centroids weren’t explicitly passed in
 *  an error in the fused_l2_nn_argmin docstring where output wasn’t defined
 *  An `AttributeError: module 'pylibraft.neighbors.ivf_pq' has no attribute 'np'` error in ivf_pq

Closes rapidsai#981
  • Loading branch information
benfred committed Dec 7, 2022
1 parent dd49a10 commit b9c2fa9
Show file tree
Hide file tree
Showing 12 changed files with 272 additions and 147 deletions.
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@
# limitations under the License.
#

from .kmeans import compute_new_centroids
from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit

__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"]
83 changes: 40 additions & 43 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ from libcpp cimport nullptr
from collections import namedtuple
from enum import IntEnum

from pylibraft.common import Handle, cai_wrapper
from pylibraft.common import Handle, cai_wrapper, device_ndarray
from pylibraft.common.handle import auto_sync_handle

from pylibraft.common.handle cimport handle_t
Expand Down Expand Up @@ -81,33 +81,33 @@ def compute_new_centroids(X,
--------
>>> import cupy as cp
>>>
>>> from pylibraft.common import Handle
>>> from pylibraft.cluster.kmeans import compute_new_centroids
>>>
>>> # A single RAFT handle can optionally be reused across
>>> # pylibraft functions.
>>> handle = Handle()
>>>
>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3
>>>
>>> X = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
>>>
... dtype=cp.float32)
>>> centroids = cp.random.random_sample((n_clusters, n_features),
>>> dtype=cp.float32)
>>>
... dtype=cp.float32)
...
>>> labels = cp.random.randint(0, high=n_clusters, size=n_samples,
>>> dtype=cp.int32)
>>>
... dtype=cp.int32)
>>> new_centroids = cp.empty((n_clusters, n_features), dtype=cp.float32)
>>>
>>> compute_new_centroids(
>>> X, centroids, labels, new_centroids, handle=handle
>>> )
>>>
... X, centroids, labels, new_centroids, handle=handle
... )
>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
>>> handle.sync()
Expand Down Expand Up @@ -211,22 +211,21 @@ def cluster_cost(X, centroids, handle=None):
Examples
--------
.. code-block:: python
import cupy as cp
from pylibraft.cluster.kmeans import cluster_cost
n_samples = 5000
n_features = 50
n_clusters = 3
X = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
>>> import cupy as cp
>>>
>>> from pylibraft.cluster.kmeans import cluster_cost
>>>
>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3
>>>
>>> X = cp.random.random_sample((n_samples, n_features),
... dtype=cp.float32)
centroids = cp.random.random_sample((n_clusters, n_features),
dtype=cp.float32)
>>> centroids = cp.random.random_sample((n_clusters, n_features),
... dtype=cp.float32)
inertia = cluster_cost(X, centroids)
>>> inertia = cluster_cost(X, centroids)
"""
x_cai = X.__cuda_array_interface__
centroids_cai = centroids.__cuda_array_interface__
Expand Down Expand Up @@ -434,21 +433,19 @@ def fit(
Examples
--------
.. code-block:: python
import cupy as cp
from pylibraft.cluster.kmeans import fit, KMeansParams
n_samples = 5000
n_features = 50
n_clusters = 3
X = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
>>> import cupy as cp
>>>
>>> from pylibraft.cluster.kmeans import fit, KMeansParams
>>>
>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3
>>>
>>> X = cp.random.random_sample((n_samples, n_features),
... dtype=cp.float32)
params = KMeansParams(n_clusters=n_clusters)
centroids, inertia, n_iter = fit(params, X)
>>> params = KMeansParams(n_clusters=n_clusters)
>>> centroids, inertia, n_iter = fit(params, X)
"""
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

Expand Down
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@

from .fused_l2_nn import fused_l2_nn_argmin
from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance

__all__ = ["fused_l2_nn_argmin", "pairwise_distance"]
20 changes: 10 additions & 10 deletions python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
>>> n_clusters = 5
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_clusters, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> # A single RAFT handle can optionally be reused across
>>> # pylibraft functions.
>>> handle = Handle()
>>> ...
>>> output = fused_l2_nn_argmin(in1, in2, output, handle=handle)
>>> ...
>>> output = fused_l2_nn_argmin(in1, in2, handle=handle)
>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
>>> handle.sync()
Expand All @@ -103,20 +103,20 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
>>> n_clusters = 5
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_clusters, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> output = cp.empty((n_samples, 1), dtype=cp.int32)
>>> # A single RAFT handle can optionally be reused across
>>> # pylibraft functions.
>>> handle = Handle()
>>> ...
>>> fused_l2_nn_argmin(in1, in2, out=output, handle=handle)
>>> ...
array(...)
>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
>>> handle.sync()
"""

x_cai = cai_wrapper(X)
Expand Down
11 changes: 6 additions & 5 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
>>> n_samples = 5000
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
A single RAFT handle can optionally be reused across
pylibraft functions.
Expand All @@ -147,9 +147,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
>>> n_samples = 5000
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> output = cp.empty((n_samples, n_samples), dtype=cp.float32)
A single RAFT handle can optionally be reused across
Expand All @@ -158,7 +158,8 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
>>>
>>> handle = Handle()
>>> pairwise_distance(in1, in2, out=output,
>>> metric="euclidean", handle=handle)
... metric="euclidean", handle=handle)
array(...)
pylibraft functions are often asynchronous so the
handle needs to be explicitly synchronized
Expand Down
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
# limitations under the License.
#
from .refine import refine

__all__ = ["refine"]
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@
#

from .ivf_pq import Index, IndexParams, SearchParams, build, extend, search

__all__ = ["Index", "IndexParams", "SearchParams", "build", "extend", "search"]
Loading

0 comments on commit b9c2fa9

Please sign in to comment.