Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically sync handle when not passed to pylibraft functions #987

Merged
merged 4 commits into from
Nov 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from libcpp cimport bool
from libcpp cimport nullptr

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t
from pylibraft.common.input_validation import *
from pylibraft.distance import DISTANCE_TYPES
Expand Down Expand Up @@ -66,6 +67,7 @@ cdef extern from "raft_distance/kmeans.hpp" \
float *weight_per_cluster)


@auto_sync_handle
def compute_new_centroids(X,
centroids,
labels,
Expand Down Expand Up @@ -97,7 +99,7 @@ def compute_new_centroids(X,
distances in batches. default: m
batch_centroids : Optional integer specifying the batch size for centroids
to compute distances in batches. default: n_clusters
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
Expand Down
41 changes: 40 additions & 1 deletion python/pylibraft/pylibraft/common/handle.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
# cython: embedsignature = True
# cython: language_level = 3

# import raft
import functools

from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread
from rmm._lib.cuda_stream_view cimport cuda_stream_view

Expand Down Expand Up @@ -87,3 +88,41 @@ cdef class Handle:

self.c_obj.reset(new handle_t(cuda_stream_per_thread,
self.stream_pool))


_HANDLE_PARAM_DOCSTRING = """
handle : Optional RAFT handle for reusing expensive CUDA resources
If a handle isn't supplied, CUDA resources will be allocated
inside this function and synchronized before the function exits.
If a handle is supplied, you will need to explicitly synchronize
yourself by calling `handle.sync()` before accessing the output.
""".strip()


def auto_sync_handle(f):
"""Decorator to automatically call sync on a raft handle when
it isn't passed to a function.
When a handle=None is passed to the wrapped function, this decorator
will automatically create a default handle for the function, and
call sync on that handle when the function exits.
This will also insert the appropriate docstring for the handle parameter
"""

@functools.wraps(f)
def wrapper(*args, handle=None, **kwargs):
sync_handle = handle is None
handle = handle if handle is not None else Handle()

ret_value = f(*args, handle=handle, **kwargs)

if sync_handle:
handle.sync()

return ret_value

wrapper.__doc__ = wrapper.__doc__.format(
handle_docstring=_HANDLE_PARAM_DOCSTRING
)
return wrapper
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ from cython.operator cimport dereference as deref
from libcpp cimport bool
from .distance_type cimport DistanceType
from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


Expand Down Expand Up @@ -59,6 +60,7 @@ cdef extern from "raft_distance/fused_l2_min_arg.hpp" \
bool sqrt)


@auto_sync_handle
def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
"""
Compute the 1-nearest neighbors between X and Y using the L2 distance
Expand All @@ -69,7 +71,7 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
X : CUDA array interface compliant matrix shape (m, k)
Y : CUDA array interface compliant matrix shape (n, k)
output : Writable CUDA array interface matrix shape (m, 1)
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
Expand Down
4 changes: 3 additions & 1 deletion python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from libcpp cimport bool
from .distance_type cimport DistanceType

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


Expand Down Expand Up @@ -90,6 +91,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",
"hamming", "jensenshannon", "cosine", "sqeuclidean"]


@auto_sync_handle
def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None):
"""
Compute pairwise distances between X and Y
Expand All @@ -108,7 +110,7 @@ def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None):
dists : Writable CUDA array interface matrix shape (m, n)
metric : string denoting the metric type (default="euclidean")
p : metric parameter (currently used only for "minkowski")
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}

Examples
--------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import numpy as np
from libc.stdint cimport uintptr_t, int64_t
from cython.operator cimport dereference as deref
from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t
from .rng_state cimport RngState

Expand Down Expand Up @@ -73,6 +74,7 @@ cdef extern from "raft_distance/random/rmat_rectangular_generator.hpp" \
RngState& r)


@auto_sync_handle
def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None):
"""
Generate RMAT adjacency list based on the input distribution.
Expand All @@ -88,7 +90,7 @@ def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None):
r_scale: log2 of number of source nodes
c_scale: log2 of number of destination nodes
seed: random seed used for reproducibility
handle : Optional RAFT handle for reusing expensive CUDA resources
{handle_docstring}
Examples
--------
Expand Down