From aee90f4518cf12475c86b02371918fd886f68a4a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 6 Dec 2022 17:14:05 -0800 Subject: [PATCH] Minor cython fixes / cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Release GIL on C++ cython declarations * Remove ‘valid values for metric’ from compute_new_centroids (since doesn’t take a metric param) * Remove is_c_cont in favour of is_c_contiguous in kmeans.pyx --- .../pylibraft/cluster/cpp/kmeans.pxd | 2 +- python/pylibraft/pylibraft/cluster/kmeans.pyx | 19 +++++-------------- .../pylibraft/distance/fused_l2_nn.pyx | 2 +- .../pylibraft/distance/pairwise_distance.pyx | 2 +- .../random/rmat_rectangular_generator.pyx | 2 +- 5 files changed, 9 insertions(+), 18 deletions(-) diff --git a/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd b/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd index 40f84cad40..059512990e 100644 --- a/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd +++ b/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd @@ -31,7 +31,7 @@ from pylibraft.common.handle cimport handle_t cdef extern from "raft_runtime/cluster/kmeans.hpp" \ - namespace "raft::runtime::cluster::kmeans": + namespace "raft::runtime::cluster::kmeans" nogil: cdef void update_centroids( const handle_t& handle, diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 7d336ab58d..ca25b45843 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -46,12 +46,6 @@ from pylibraft.common.cpp.optional cimport optional from pylibraft.common.handle cimport handle_t -def is_c_cont(cai, dt): - return "strides" not in cai or \ - cai["strides"] is None or \ - cai["strides"][1] == dt.itemsize - - @auto_sync_handle def compute_new_centroids(X, centroids, @@ -63,9 +57,6 @@ def compute_new_centroids(X, """ Compute new centroids given an input matrix and existing centroids - Valid values for metric: - ["euclidean", "sqeuclidean"] - Parameters ---------- @@ -167,9 +158,9 @@ def compute_new_centroids(X, handle = handle if handle is not None else Handle() cdef handle_t *h = handle.getHandle() - x_c_contiguous = is_c_cont(x_cai, x_dt) - centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt) - new_centroids_c_contiguous = is_c_cont(new_centroids_cai, new_centroids_dt) + x_c_contiguous = is_c_contiguous(x_cai) + centroids_c_contiguous = is_c_contiguous(centroids_cai) + new_centroids_c_contiguous = is_c_contiguous(new_centroids_cai) if not x_c_contiguous or not centroids_c_contiguous \ or not new_centroids_c_contiguous: @@ -258,8 +249,8 @@ def cluster_cost(X, centroids, handle=None): handle = handle if handle is not None else Handle() cdef handle_t *h = handle.getHandle() - x_c_contiguous = is_c_cont(x_cai, x_dt) - centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt) + x_c_contiguous = is_c_contiguous(x_cai) + centroids_c_contiguous = is_c_contiguous(centroids_cai) if not x_c_contiguous or not centroids_c_contiguous: raise ValueError("Inputs must all be c contiguous") diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index e22afa99bf..81a81b2632 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -33,7 +33,7 @@ from pylibraft.common.handle cimport handle_t cdef extern from "raft_runtime/distance/fused_l2_nn.hpp" \ - namespace "raft::runtime::distance": + namespace "raft::runtime::distance" nogil: void fused_l2_nn_min_arg( const handle_t &handle, diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 8296d5bc09..450444f953 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -35,7 +35,7 @@ from pylibraft.common import cai_wrapper, device_ndarray cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \ - namespace "raft::runtime::distance": + namespace "raft::runtime::distance" nogil: cdef void pairwise_distance(const handle_t &handle, float *x, diff --git a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx index dbc2e25a29..17c574bea5 100644 --- a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx +++ b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx @@ -33,7 +33,7 @@ from pylibraft.random.cpp.rng_state cimport RngState cdef extern from "raft_runtime/random/rmat_rectangular_generator.hpp" \ - namespace "raft::runtime::random": + namespace "raft::runtime::random" nogil: cdef void rmat_rectangular_gen(const handle_t &handle, int* out,