# Shared numpydoc entry for the ``handle`` parameter. ``auto_sync_handle``
# substitutes it for the ``{handle_docstring}`` placeholder in the docstring
# of every function it decorates.
_HANDLE_PARAM_DOCSTRING = """
    handle : Optional RAFT handle for reusing expensive CUDA resources
        If a handle isn't supplied, CUDA resources will be allocated
        inside this function and synchronized before the function exits.
        If a handle is supplied, you will need to explicitly synchronize
        yourself by calling `handle.sync()` before accessing the output.
""".strip()


def auto_sync_handle(f):
    """Decorator to automatically call sync on a raft handle when
    it isn't passed to a function.

    When ``handle`` is omitted (or ``None``) in a call to the wrapped
    function, this decorator creates a default ``Handle`` for the call and
    calls ``sync()`` on it after the wrapped function returns. A handle
    supplied by the caller, by keyword or positionally, is forwarded
    untouched and is *not* synchronized.

    The ``{handle_docstring}`` placeholder in the wrapped function's
    docstring, if present, is replaced with the shared description of the
    ``handle`` parameter.

    Parameters
    ----------
    f : callable
        Function accepting a ``handle`` argument.

    Returns
    -------
    wrapper : callable
        Wrapped function with the same calling convention as ``f``.
    """
    # Imported locally so this block does not depend on a module-level import.
    import inspect

    # Find where ``handle`` sits among the positional parameters so that a
    # handle passed positionally is recognised instead of colliding with the
    # ``handle=`` keyword forwarded below. Introspection can fail for
    # callables that expose no signature; then only keyword handles are
    # recognised.
    handle_index = None
    try:
        parameters = inspect.signature(f).parameters
    except (TypeError, ValueError):
        parameters = {}
    for index, param in enumerate(parameters.values()):
        if param.name == "handle":
            if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                handle_index = index
            break

    @functools.wraps(f)
    def wrapper(*args, handle=None, **kwargs):
        if (
            handle is None
            and handle_index is not None
            and len(args) > handle_index
        ):
            # The caller supplied the handle positionally: lift it out of
            # ``args`` so it is not passed to ``f`` twice.
            handle = args[handle_index]
            args = args[:handle_index] + args[handle_index + 1:]

        # Only a handle created here is synchronized here; a caller-provided
        # handle is the caller's responsibility.
        sync_handle = handle is None
        if sync_handle:
            handle = Handle()

        ret_value = f(*args, handle=handle, **kwargs)

        if sync_handle:
            handle.sync()

        return ret_value

    # ``__doc__`` is None for undocumented functions and under ``python -OO``.
    # A plain replace (rather than ``str.format``) keeps any other braces in
    # the docstring, e.g. dict literals in examples, from raising.
    if wrapper.__doc__ is not None:
        wrapper.__doc__ = wrapper.__doc__.replace(
            "{handle_docstring}", _HANDLE_PARAM_DOCSTRING
        )
    return wrapper
b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 0f7626e8d1..40f2fa668c 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from .distance_type cimport DistanceType from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -90,6 +91,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product", "hamming", "jensenshannon", "cosine", "sqeuclidean"] +@auto_sync_handle def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None): """ Compute pairwise distances between X and Y @@ -108,7 +110,7 @@ def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None): dists : Writable CUDA array interface matrix shape (m, n) metric : string denoting the metric type (default="euclidean") p : metric parameter (currently used only for "minkowski") - handle : Optional RAFT handle for reusing expensive CUDA resources + {handle_docstring} Examples -------- diff --git a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx index cef19295ac..0aa90b6bc4 100644 --- a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx +++ b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx @@ -23,6 +23,7 @@ import numpy as np from libc.stdint cimport uintptr_t, int64_t from cython.operator cimport dereference as deref from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t from .rng_state cimport RngState @@ -73,6 +74,7 @@ cdef extern from "raft_distance/random/rmat_rectangular_generator.hpp" \ RngState& r) +@auto_sync_handle def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None): """ Generate RMAT adjacency list based on the input distribution. 
@@ -88,7 +90,7 @@ def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None): r_scale: log2 of number of source nodes c_scale: log2 of number of destination nodes seed: random seed used for reproducibility - handle : Optional RAFT handle for reusing expensive CUDA resources + {handle_docstring} Examples --------