From 5950a6d05e7b0b1e180398ed8766d2d5e88219a6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 3 Nov 2022 16:50:52 -0700 Subject: [PATCH 1/4] Automatically sync handle when not passed to pylibraft functions Closes https://github.com/rapidsai/raft/issues/971 --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 2 ++ python/pylibraft/pylibraft/common/handle.pyx | 27 ++++++++++++++++++- .../pylibraft/distance/fused_l2_nn.pyx | 2 ++ .../pylibraft/distance/pairwise_distance.pyx | 2 ++ .../random/rmat_rectangular_generator.pyx | 2 ++ 5 files changed, 34 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index c2d445f970..06e6338a33 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from libcpp cimport nullptr from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t from pylibraft.common.input_validation import * from pylibraft.distance import DISTANCE_TYPES @@ -66,6 +67,7 @@ cdef extern from "raft_distance/kmeans.hpp" \ float *weight_per_cluster) +@auto_sync_handle def compute_new_centroids(X, centroids, labels, diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx index 83a4676076..cd1c717d07 100644 --- a/python/pylibraft/pylibraft/common/handle.pyx +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -19,7 +19,8 @@ # cython: embedsignature = True # cython: language_level = 3 -# import raft +import functools + from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread from rmm._lib.cuda_stream_view cimport cuda_stream_view @@ -87,3 +88,27 @@ cdef class Handle: self.c_obj.reset(new handle_t(cuda_stream_per_thread, self.stream_pool)) + + +def auto_sync_handle(f): + """Decorator to automatically call sync on a raft handle when + it isn't passed to a function. + + When a handle=None is passed to the wrapped function, this decorator + will automatically create a default handle for the function, and + call sync on that handle when the function exits. + """ + + @functools.wraps(f) + def wrapper(*args, handle=None, **kwargs): + sync_handle = handle is None + handle = handle if handle is not None else Handle() + + ret_value = f(*args, handle=handle, **kwargs) + + if sync_handle: + handle.sync() + + return ret_value + + return wrapper diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index 7abc32119b..c78323dda1 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -26,6 +26,7 @@ from cython.operator cimport dereference as deref from libcpp cimport bool from .distance_type cimport DistanceType from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -59,6 +60,7 @@ cdef extern from "raft_distance/fused_l2_min_arg.hpp" \ bool sqrt) +@auto_sync_handle def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None): """ Compute the 1-nearest neighbors between X and Y using the L2 distance diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 0f7626e8d1..8e84f1a204 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from .distance_type cimport DistanceType from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -90,6 +91,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product", "hamming", "jensenshannon", "cosine", "sqeuclidean"] +@auto_sync_handle def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None): """ Compute pairwise distances between X and Y diff --git a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx index cef19295ac..f5a843215a 100644 --- a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx +++ b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx @@ -23,6 +23,7 @@ import numpy as np from libc.stdint cimport uintptr_t, int64_t from cython.operator cimport dereference as deref from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t from .rng_state cimport RngState @@ -73,6 +74,7 @@ cdef extern from "raft_distance/random/rmat_rectangular_generator.hpp" \ RngState& r) +@auto_sync_handle def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None): """ Generate RMAT adjacency list based on the input distribution. From 1bc4c2a03beb6043ee32b2e7de73daf61b95c23b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 3 Nov 2022 17:01:23 -0700 Subject: [PATCH 2/4] style --- python/pylibraft/pylibraft/common/handle.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx index cd1c717d07..7221c5f122 100644 --- a/python/pylibraft/pylibraft/common/handle.pyx +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -95,7 +95,7 @@ def auto_sync_handle(f): it isn't passed to a function. When a handle=None is passed to the wrapped function, this decorator - will automatically create a default handle for the function, and + will automatically create a default handle for the function, and call sync on that handle when the function exits. """ From 8685ee3a592837b2b302382c28dab5418b53def8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 4 Nov 2022 11:03:11 -0700 Subject: [PATCH 3/4] Update docstring for handle --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 2 +- python/pylibraft/pylibraft/common/handle.pyx | 12 ++++++++++++ python/pylibraft/pylibraft/distance/fused_l2_nn.pyx | 2 +- .../pylibraft/distance/pairwise_distance.pyx | 2 +- .../pylibraft/random/rmat_rectangular_generator.pyx | 2 +- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 06e6338a33..edbc30a1f0 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -99,7 +99,7 @@ def compute_new_centroids(X, distances in batches. default: m batch_centroids : Optional integer specifying the batch size for centroids to compute distances in batches. default: n_clusters - handle : Optional RAFT handle for reusing expensive CUDA resources + {handle_docstring} Examples -------- diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx index 7221c5f122..f820d7a96a 100644 --- a/python/pylibraft/pylibraft/common/handle.pyx +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -90,6 +90,15 @@ cdef class Handle: self.stream_pool)) +_HANDLE_PARAM_DOCSTRING = """ + handle : Optional RAFT handle for reusing expensive CUDA resources + If a handle isn't supplied, CUDA resources will be allocated + inside this function and synchronized before the function exits. + If a handle is supplied, you will need to explicitly synchronize + yourself by calling `handle.sync()` before accessing the output. +""".strip() + + def auto_sync_handle(f): """Decorator to automatically call sync on a raft handle when it isn't passed to a function. @@ -97,6 +106,8 @@ def auto_sync_handle(f): When a handle=None is passed to the wrapped function, this decorator will automatically create a default handle for the function, and call sync on that handle when the function exits. + + This will also insert the appropriate docstring for the handle parameter """ @functools.wraps(f) @@ -111,4 +122,5 @@ def auto_sync_handle(f): return ret_value + wrapper.__doc__ = wrapper.__doc__.format(handle_docstring=_HANDLE_PARAM_DOCSTRING) return wrapper diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index c78323dda1..7e05e413aa 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -71,7 +71,7 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None): X : CUDA array interface compliant matrix shape (m, k) Y : CUDA array interface compliant matrix shape (n, k) output : Writable CUDA array interface matrix shape (m, 1) - handle : Optional RAFT handle for reusing expensive CUDA resources + {handle_docstring} Examples -------- diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 8e84f1a204..40f2fa668c 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -110,7 +110,7 @@ def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None): dists : Writable CUDA array interface matrix shape (m, n) metric : string denoting the metric type (default="euclidean") p : metric parameter (currently used only for "minkowski") - handle : Optional RAFT handle for reusing expensive CUDA resources + {handle_docstring} Examples -------- diff --git a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx index f5a843215a..0aa90b6bc4 100644 --- a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx +++ b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx @@ -90,7 +90,7 @@ def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None): r_scale: log2 of number of source nodes c_scale: log2 of number of destination nodes seed: random seed used for reproducibility - handle : Optional RAFT handle for reusing expensive CUDA resources + {handle_docstring} Examples -------- From 1ddb9bf6eec9bafb89c2d84c71b6248a680e42e4 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 4 Nov 2022 11:41:28 -0700 Subject: [PATCH 4/4] style --- python/pylibraft/pylibraft/common/handle.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx index f820d7a96a..3e9ed569ad 100644 --- a/python/pylibraft/pylibraft/common/handle.pyx +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -122,5 +122,7 @@ def auto_sync_handle(f): return ret_value - wrapper.__doc__ = wrapper.__doc__.format(handle_docstring=_HANDLE_PARAM_DOCSTRING) + wrapper.__doc__ = wrapper.__doc__.format( + handle_docstring=_HANDLE_PARAM_DOCSTRING + ) return wrapper