Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use doctest for testing python example docstrings #1073

Merged
merged 2 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@
# limitations under the License.
#

from .kmeans import compute_new_centroids
from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit

__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"]
83 changes: 40 additions & 43 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ from libcpp cimport nullptr
from collections import namedtuple
from enum import IntEnum

from pylibraft.common import Handle, cai_wrapper
from pylibraft.common import Handle, cai_wrapper, device_ndarray
from pylibraft.common.handle import auto_sync_handle

from pylibraft.common.handle cimport handle_t
Expand Down Expand Up @@ -81,33 +81,33 @@ def compute_new_centroids(X,
--------

>>> import cupy as cp
>>>

>>> from pylibraft.common import Handle
>>> from pylibraft.cluster.kmeans import compute_new_centroids
>>>

>>> # A single RAFT handle can optionally be reused across
>>> # pylibraft functions.
>>> handle = Handle()
>>>

>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3
>>>

>>> X = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
>>>
... dtype=cp.float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is good to know. So this is how we represent the continuation of the previous line in doctest examples within Python docstrings.


>>> centroids = cp.random.random_sample((n_clusters, n_features),
>>> dtype=cp.float32)
>>>
... dtype=cp.float32)
...
>>> labels = cp.random.randint(0, high=n_clusters, size=n_samples,
>>> dtype=cp.int32)
>>>
... dtype=cp.int32)

>>> new_centroids = cp.empty((n_clusters, n_features), dtype=cp.float32)
>>>

>>> compute_new_centroids(
>>> X, centroids, labels, new_centroids, handle=handle
>>> )
>>>
... X, centroids, labels, new_centroids, handle=handle
... )

>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
>>> handle.sync()
Expand Down Expand Up @@ -211,22 +211,21 @@ def cluster_cost(X, centroids, handle=None):
Examples
--------

.. code-block:: python
import cupy as cp

from pylibraft.cluster.kmeans import cluster_cost

n_samples = 5000
n_features = 50
n_clusters = 3

X = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
>>> import cupy as cp
>>>
>>> from pylibraft.cluster.kmeans import cluster_cost
>>>
>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3
>>>
>>> X = cp.random.random_sample((n_samples, n_features),
... dtype=cp.float32)

centroids = cp.random.random_sample((n_clusters, n_features),
dtype=cp.float32)
>>> centroids = cp.random.random_sample((n_clusters, n_features),
... dtype=cp.float32)

inertia = cluster_cost(X, centroids)
>>> inertia = cluster_cost(X, centroids)
"""
x_cai = X.__cuda_array_interface__
centroids_cai = centroids.__cuda_array_interface__
Expand Down Expand Up @@ -434,21 +433,19 @@ def fit(
Examples
--------

.. code-block:: python

import cupy as cp

from pylibraft.cluster.kmeans import fit, KMeansParams

n_samples = 5000
n_features = 50
n_clusters = 3

X = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
>>> import cupy as cp
>>>
>>> from pylibraft.cluster.kmeans import fit, KMeansParams
>>>
>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3
>>>
>>> X = cp.random.random_sample((n_samples, n_features),
... dtype=cp.float32)

params = KMeansParams(n_clusters=n_clusters)
centroids, inertia, n_iter = fit(params, X)
>>> params = KMeansParams(n_clusters=n_clusters)
>>> centroids, inertia, n_iter = fit(params, X)
"""
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

Expand Down
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@

from .fused_l2_nn import fused_l2_nn_argmin
from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance

__all__ = ["fused_l2_nn_argmin", "pairwise_distance"]
20 changes: 10 additions & 10 deletions python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
>>> n_clusters = 5
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_clusters, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> # A single RAFT handle can optionally be reused across
>>> # pylibraft functions.
>>> handle = Handle()
>>> ...
>>> output = fused_l2_nn_argmin(in1, in2, output, handle=handle)
>>> ...

>>> output = fused_l2_nn_argmin(in1, in2, handle=handle)

>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
>>> handle.sync()
Expand All @@ -103,20 +103,20 @@ def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
>>> n_clusters = 5
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_clusters, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> output = cp.empty((n_samples, 1), dtype=cp.int32)
>>> # A single RAFT handle can optionally be reused across
>>> # pylibraft functions.
>>> handle = Handle()
>>> ...

>>> fused_l2_nn_argmin(in1, in2, out=output, handle=handle)
>>> ...
array(...)

>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
>>> handle.sync()

"""

x_cai = cai_wrapper(X)
Expand Down
11 changes: 6 additions & 5 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
>>> n_samples = 5000
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)

A single RAFT handle can optionally be reused across
pylibraft functions.
Expand All @@ -147,9 +147,9 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
>>> n_samples = 5000
>>> n_features = 50
>>> in1 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> in2 = cp.random.random_sample((n_samples, n_features),
>>> dtype=cp.float32)
... dtype=cp.float32)
>>> output = cp.empty((n_samples, n_samples), dtype=cp.float32)

A single RAFT handle can optionally be reused across
Expand All @@ -158,7 +158,8 @@ def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None):
>>>
>>> handle = Handle()
>>> pairwise_distance(in1, in2, out=output,
>>> metric="euclidean", handle=handle)
... metric="euclidean", handle=handle)
array(...)

pylibraft functions are often asynchronous so the
handle needs to be explicitly synchronized
Expand Down
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
# limitations under the License.
#
from .refine import refine

__all__ = ["refine"]
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@
#

from .ivf_pq import Index, IndexParams, SearchParams, build, extend, search

__all__ = ["Index", "IndexParams", "SearchParams", "build", "extend", "search"]
Loading