Skip to content

Commit

Permalink
Minor cython fixes / cleanup (rapidsai#1072)
Browse files Browse the repository at this point in the history
* Release GIL on C++ cython declarations
* Remove 'valid values for metric' mention from the compute_new_centroids docstring (since it doesn't take a metric parameter)
* Remove `is_c_cont` in favour of `input_validation.is_c_contiguous` in kmeans.pyx

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1072
  • Loading branch information
benfred authored Dec 7, 2022
1 parent 092c515 commit dd49a10
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ from pylibraft.common.handle cimport handle_t


cdef extern from "raft_runtime/cluster/kmeans.hpp" \
namespace "raft::runtime::cluster::kmeans":
namespace "raft::runtime::cluster::kmeans" nogil:

cdef void update_centroids(
const handle_t& handle,
Expand Down
19 changes: 5 additions & 14 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport handle_t


def is_c_cont(cai, dt):
return "strides" not in cai or \
cai["strides"] is None or \
cai["strides"][1] == dt.itemsize


@auto_sync_handle
def compute_new_centroids(X,
centroids,
Expand All @@ -63,9 +57,6 @@ def compute_new_centroids(X,
"""
Compute new centroids given an input matrix and existing centroids
Valid values for metric:
["euclidean", "sqeuclidean"]
Parameters
----------
Expand Down Expand Up @@ -167,9 +158,9 @@ def compute_new_centroids(X,
handle = handle if handle is not None else Handle()
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

x_c_contiguous = is_c_cont(x_cai, x_dt)
centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt)
new_centroids_c_contiguous = is_c_cont(new_centroids_cai, new_centroids_dt)
x_c_contiguous = is_c_contiguous(x_cai)
centroids_c_contiguous = is_c_contiguous(centroids_cai)
new_centroids_c_contiguous = is_c_contiguous(new_centroids_cai)

if not x_c_contiguous or not centroids_c_contiguous \
or not new_centroids_c_contiguous:
Expand Down Expand Up @@ -258,8 +249,8 @@ def cluster_cost(X, centroids, handle=None):
handle = handle if handle is not None else Handle()
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

x_c_contiguous = is_c_cont(x_cai, x_dt)
centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt)
x_c_contiguous = is_c_contiguous(x_cai)
centroids_c_contiguous = is_c_contiguous(centroids_cai)

if not x_c_contiguous or not centroids_c_contiguous:
raise ValueError("Inputs must all be c contiguous")
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ from pylibraft.common.handle cimport handle_t


cdef extern from "raft_runtime/distance/fused_l2_nn.hpp" \
namespace "raft::runtime::distance":
namespace "raft::runtime::distance" nogil:

void fused_l2_nn_min_arg(
const handle_t &handle,
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ from pylibraft.common import cai_wrapper, device_ndarray


cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \
namespace "raft::runtime::distance":
namespace "raft::runtime::distance" nogil:

cdef void pairwise_distance(const handle_t &handle,
float *x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ from pylibraft.random.cpp.rng_state cimport RngState


cdef extern from "raft_runtime/random/rmat_rectangular_generator.hpp" \
namespace "raft::runtime::random":
namespace "raft::runtime::random" nogil:

cdef void rmat_rectangular_gen(const handle_t &handle,
int* out,
Expand Down

0 comments on commit dd49a10

Please sign in to comment.