From a4ca4bb576c3f115d381c6d848d8dcfa5ec8ecbb Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 2 Nov 2022 14:59:03 -0700 Subject: [PATCH 1/2] Fix pylibraft docstring example code Fix the pylibraft example code given in the docstrings so that they run without exceptions. --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 4 ++-- python/pylibraft/pylibraft/common/cuda.pyx | 2 +- python/pylibraft/pylibraft/distance/fused_l2_nn.pyx | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index c2d445f970..e88813cfbb 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -107,7 +107,7 @@ def compute_new_centroids(X, import cupy as cp from pylibraft.common import Handle - from pylibaft.cluster.kmeans import update_centroids + from pylibraft.cluster.kmeans import compute_new_centroids from pylibraft.distance import fused_l2_nn_argmin # A single RAFT handle can optionally be reused across @@ -129,7 +129,7 @@ def compute_new_centroids(X, new_centroids = cp.empty((n_clusters, n_features), dtype=cp.float32) - compute_new_centroids(X, centroids, new_centroids, handle=handle) + compute_new_centroids(X, centroids, labels, new_centroids, handle=handle) # pylibraft functions are often asynchronous so the # handle needs to be explicitly synchronized diff --git a/python/pylibraft/pylibraft/common/cuda.pyx b/python/pylibraft/pylibraft/common/cuda.pyx index eb48f64cf1..9b35aebdba 100644 --- a/python/pylibraft/pylibraft/common/cuda.pyx +++ b/python/pylibraft/pylibraft/common/cuda.pyx @@ -52,7 +52,7 @@ cdef class Stream: .. code-block:: python - from raft.common.cuda import Stream + from pylibraft.common.cuda import Stream stream = Stream() stream.sync() del stream # optional! diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index 7abc32119b..abb8f4d86e 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -79,7 +79,7 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None): import cupy as cp from pylibraft.common import Handle - from pylibraft.distance import fused_l2_nn + from pylibraft.distance import fused_l2_nn_argmin n_samples = 5000 n_clusters = 5 From 9aef932c5a91db5b0fcd57a1b005c67245cfac7c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 2 Nov 2022 15:11:25 -0700 Subject: [PATCH 2/2] Fix style --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index e88813cfbb..39cfcda361 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -129,7 +129,9 @@ def compute_new_centroids(X, new_centroids = cp.empty((n_clusters, n_features), dtype=cp.float32) - compute_new_centroids(X, centroids, labels, new_centroids, handle=handle) + compute_new_centroids( + X, centroids, labels, new_centroids, handle=handle + ) # pylibraft functions are often asynchronous so the # handle needs to be explicitly synchronized