Automatically sync handle when not passed to pylibraft functions (#987)
Closes #971

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #987
benfred authored Nov 7, 2022
1 parent d6234f6 commit 1d78d88
Showing 5 changed files with 52 additions and 5 deletions.
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/cluster/kmeans.pyx
@@ -27,6 +27,7 @@ from libcpp cimport bool
from libcpp cimport nullptr

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t
from pylibraft.common.input_validation import *
from pylibraft.distance import DISTANCE_TYPES
@@ -66,6 +67,7 @@ cdef extern from "raft_distance/kmeans.hpp" \
float *weight_per_cluster)


@auto_sync_handle
def compute_new_centroids(X,
                          centroids,
                          labels,
Expand Down Expand Up @@ -97,7 +99,7 @@ def compute_new_centroids(X,
distances in batches. default: m
batch_centroids : Optional integer specifying the batch size for centroids
to compute distances in batches. default: n_clusters
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
41 changes: 40 additions & 1 deletion python/pylibraft/pylibraft/common/handle.pyx
@@ -19,7 +19,8 @@
# cython: embedsignature = True
# cython: language_level = 3

# import raft
import functools

from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread
from rmm._lib.cuda_stream_view cimport cuda_stream_view

@@ -87,3 +88,41 @@ cdef class Handle:

        self.c_obj.reset(new handle_t(cuda_stream_per_thread,
                                      self.stream_pool))


_HANDLE_PARAM_DOCSTRING = """
handle : Optional RAFT handle for reusing expensive CUDA resources
    If a handle isn't supplied, CUDA resources will be allocated
    inside this function and synchronized before the function exits.
    If a handle is supplied, you will need to explicitly synchronize
    yourself by calling `handle.sync()` before accessing the output.
""".strip()


def auto_sync_handle(f):
    """Decorator to automatically call sync on a RAFT handle when one
    isn't passed to the wrapped function.

    When handle=None is passed to the wrapped function, this decorator
    creates a default handle for the call and calls sync on that handle
    when the function exits. It also inserts the appropriate docstring
    for the handle parameter.
    """

    @functools.wraps(f)
    def wrapper(*args, handle=None, **kwargs):
        sync_handle = handle is None
        handle = handle if handle is not None else Handle()

        ret_value = f(*args, handle=handle, **kwargs)

        if sync_handle:
            handle.sync()

        return ret_value

    wrapper.__doc__ = wrapper.__doc__.format(
        handle_docstring=_HANDLE_PARAM_DOCSTRING
    )
    return wrapper
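
As a minimal sketch of how a function picks up this behavior (not part of the commit itself): only Handle and auto_sync_handle below come from pylibraft; the toy scaled_sum function and the numpy arrays are hypothetical, standing in for a real pylibraft kernel.

import numpy as np

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle


# Toy function purely for illustration -- not part of this commit.
@auto_sync_handle
def scaled_sum(x, y, scale=1.0, handle=None):
    """Add two arrays and scale the result.

    Parameters
    ----------
    x : array
    y : array
    scale : float
    {handle_docstring}
    """
    # A real pylibraft function would launch work on the handle's stream here.
    return scale * (x + y)


x = np.arange(4, dtype=np.float32)
y = np.ones(4, dtype=np.float32)

# handle omitted: the decorator creates a default Handle and calls
# handle.sync() before returning, so the result is ready to use.
out = scaled_sum(x, y)

# handle supplied: the caller is responsible for synchronizing.
handle = Handle()
out = scaled_sum(x, y, scale=2.0, handle=handle)
handle.sync()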
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
@@ -26,6 +26,7 @@ from cython.operator cimport dereference as deref
from libcpp cimport bool
from .distance_type cimport DistanceType
from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


@@ -59,6 +60,7 @@ cdef extern from "raft_distance/fused_l2_min_arg.hpp" \
bool sqrt)


@auto_sync_handle
def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
"""
Compute the 1-nearest neighbors between X and Y using the L2 distance
@@ -69,7 +71,7 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
X : CUDA array interface compliant matrix shape (m, k)
Y : CUDA array interface compliant matrix shape (n, k)
output : Writable CUDA array interface matrix shape (m, 1)
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
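
A hedged usage sketch for fused_l2_nn_argmin with and without a handle; shapes follow the docstring above (X is (m, k), Y is (n, k), output is (m, 1)), while the cupy arrays, the int32 output dtype, and the re-export from pylibraft.distance are assumptions, not taken from this diff.

import cupy as cp

from pylibraft.common import Handle
from pylibraft.distance import fused_l2_nn_argmin

m, n, k = 100, 1000, 32

X = cp.random.random_sample((m, k), dtype=cp.float32)
Y = cp.random.random_sample((n, k), dtype=cp.float32)
output = cp.empty((m, 1), dtype=cp.int32)  # assumed index dtype

# handle omitted: a default Handle is created and synchronized on exit,
# so `output` is safe to read immediately.
fused_l2_nn_argmin(X, Y, output, sqrt=True)

# handle supplied: synchronize explicitly before reading `output`.
handle = Handle()
fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=handle)
handle.sync()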
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/distance/pairwise_distance.pyx
@@ -27,6 +27,7 @@ from libcpp cimport bool
from .distance_type cimport DistanceType

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


@@ -90,6 +91,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",
"hamming", "jensenshannon", "cosine", "sqeuclidean"]


@auto_sync_handle
def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None):
"""
Compute pairwise distances between X and Y
@@ -108,7 +110,7 @@ def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None):
dists : Writable CUDA array interface matrix shape (m, n)
metric : string denoting the metric type (default="euclidean")
p : metric parameter (currently used only for "minkowski")
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
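
A hedged usage sketch for distance(); shapes follow the docstring above, while the cupy arrays and float32 dists dtype are assumptions. The pairwise_distance submodule import mirrors the file path shown in this diff.

import cupy as cp

from pylibraft.common import Handle
from pylibraft.distance import pairwise_distance

m, n, k = 500, 8, 16

X = cp.random.random_sample((m, k), dtype=cp.float32)
Y = cp.random.random_sample((n, k), dtype=cp.float32)
dists = cp.empty((m, n), dtype=cp.float32)

# handle omitted: created and synchronized automatically by the decorator.
pairwise_distance.distance(X, Y, dists, metric="euclidean")

# handle supplied: the caller synchronizes before reading `dists`.
handle = Handle()
pairwise_distance.distance(X, Y, dists, metric="euclidean", handle=handle)
handle.sync()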
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/random/rmat_rectangular_generator.pyx
@@ -23,6 +23,7 @@ import numpy as np
from libc.stdint cimport uintptr_t, int64_t
from cython.operator cimport dereference as deref
from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t
from .rng_state cimport RngState

@@ -73,6 +74,7 @@ cdef extern from "raft_distance/random/rmat_rectangular_generator.hpp" \
RngState& r)


@auto_sync_handle
def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None):
"""
Generate RMAT adjacency list based on the input distribution.
@@ -88,7 +90,7 @@ def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None):
r_scale: log2 of number of source nodes
c_scale: log2 of number of destination nodes
seed: random seed used for reproducibility
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
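
A heavily hedged sketch of calling rmat() with and without a handle; the (n_edges, 2) int32 output of (source, destination) pairs, the theta length of four quadrant probabilities per scale level, and the pylibraft.random re-export are assumptions, not stated in this diff.

import cupy as cp

from pylibraft.common import Handle
from pylibraft.random import rmat

n_edges = 5000
r_scale, c_scale = 16, 14

# Assumed layout: 4 quadrant probabilities per scale level, and an
# (n_edges, 2) int32 output holding (source, destination) pairs.
theta = cp.random.random_sample(4 * max(r_scale, c_scale), dtype=cp.float32)
out = cp.empty((n_edges, 2), dtype=cp.int32)

# handle omitted: the decorator allocates one and syncs before returning.
rmat(out, theta, r_scale, c_scale, seed=42)

# handle supplied: sync before reading `out`.
handle = Handle()
rmat(out, theta, r_scale, c_scale, seed=42, handle=handle)
handle.sync()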
