From 5950a6d05e7b0b1e180398ed8766d2d5e88219a6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 3 Nov 2022 16:50:52 -0700 Subject: [PATCH] Automatically sync handle when not passed to pylibraft functions Closes https://github.com/rapidsai/raft/issues/971 --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 2 ++ python/pylibraft/pylibraft/common/handle.pyx | 27 ++++++++++++++++++- .../pylibraft/distance/fused_l2_nn.pyx | 2 ++ .../pylibraft/distance/pairwise_distance.pyx | 2 ++ .../random/rmat_rectangular_generator.pyx | 2 ++ 5 files changed, 34 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index c2d445f970..06e6338a33 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from libcpp cimport nullptr from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t from pylibraft.common.input_validation import * from pylibraft.distance import DISTANCE_TYPES @@ -66,6 +67,7 @@ cdef extern from "raft_distance/kmeans.hpp" \ float *weight_per_cluster) +@auto_sync_handle def compute_new_centroids(X, centroids, labels, diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx index 83a4676076..cd1c717d07 100644 --- a/python/pylibraft/pylibraft/common/handle.pyx +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -19,7 +19,8 @@ # cython: embedsignature = True # cython: language_level = 3 -# import raft +import functools + from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread from rmm._lib.cuda_stream_view cimport cuda_stream_view @@ -87,3 +88,27 @@ cdef class Handle: self.c_obj.reset(new handle_t(cuda_stream_per_thread, self.stream_pool)) + + +def auto_sync_handle(f): + """Decorator to automatically call sync on a raft handle when + it isn't passed to a function. + + When a handle=None is passed to the wrapped function, this decorator + will automatically create a default handle for the function, and + call sync on that handle when the function exits. + """ + + @functools.wraps(f) + def wrapper(*args, handle=None, **kwargs): + sync_handle = handle is None + handle = handle if handle is not None else Handle() + + ret_value = f(*args, handle=handle, **kwargs) + + if sync_handle: + handle.sync() + + return ret_value + + return wrapper diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index 7abc32119b..c78323dda1 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -26,6 +26,7 @@ from cython.operator cimport dereference as deref from libcpp cimport bool from .distance_type cimport DistanceType from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -59,6 +60,7 @@ cdef extern from "raft_distance/fused_l2_min_arg.hpp" \ bool sqrt) +@auto_sync_handle def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None): """ Compute the 1-nearest neighbors between X and Y using the L2 distance diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 0f7626e8d1..8e84f1a204 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from .distance_type cimport DistanceType from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -90,6 +91,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product", "hamming", "jensenshannon", "cosine", "sqeuclidean"] +@auto_sync_handle def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None): """ Compute pairwise distances between X and Y diff --git a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx index cef19295ac..f5a843215a 100644 --- a/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx +++ b/python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx @@ -23,6 +23,7 @@ import numpy as np from libc.stdint cimport uintptr_t, int64_t from cython.operator cimport dereference as deref from pylibraft.common import Handle +from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t from .rng_state cimport RngState @@ -73,6 +74,7 @@ cdef extern from "raft_distance/random/rmat_rectangular_generator.hpp" \ RngState& r) +@auto_sync_handle def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None): """ Generate RMAT adjacency list based on the input distribution.