Skip to content

Commit

Permalink
Automatically sync handle when not passed to pylibraft functions
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Nov 3, 2022
1 parent d6234f6 commit 5950a6d
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 1 deletion.
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from libcpp cimport bool
from libcpp cimport nullptr

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t
from pylibraft.common.input_validation import *
from pylibraft.distance import DISTANCE_TYPES
Expand Down Expand Up @@ -66,6 +67,7 @@ cdef extern from "raft_distance/kmeans.hpp" \
float *weight_per_cluster)


@auto_sync_handle
def compute_new_centroids(X,
centroids,
labels,
Expand Down
27 changes: 26 additions & 1 deletion python/pylibraft/pylibraft/common/handle.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
# cython: embedsignature = True
# cython: language_level = 3

# import raft
import functools

from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread
from rmm._lib.cuda_stream_view cimport cuda_stream_view

Expand Down Expand Up @@ -87,3 +88,27 @@ cdef class Handle:

self.c_obj.reset(new handle_t(cuda_stream_per_thread,
self.stream_pool))


def auto_sync_handle(f):
"""Decorator to automatically call sync on a raft handle when
it isn't passed to a function.
When a handle=None is passed to the wrapped function, this decorator
will automatically create a default handle for the function, and
call sync on that handle when the function exits.
"""

@functools.wraps(f)
def wrapper(*args, handle=None, **kwargs):
sync_handle = handle is None
handle = handle if handle is not None else Handle()

ret_value = f(*args, handle=handle, **kwargs)

if sync_handle:
handle.sync()

return ret_value

return wrapper
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ from cython.operator cimport dereference as deref
from libcpp cimport bool
from .distance_type cimport DistanceType
from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


Expand Down Expand Up @@ -59,6 +60,7 @@ cdef extern from "raft_distance/fused_l2_min_arg.hpp" \
bool sqrt)


@auto_sync_handle
def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
"""
Compute the 1-nearest neighbors between X and Y using the L2 distance
Expand Down
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from libcpp cimport bool
from .distance_type cimport DistanceType

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t


Expand Down Expand Up @@ -90,6 +91,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",
"hamming", "jensenshannon", "cosine", "sqeuclidean"]


@auto_sync_handle
def distance(X, Y, dists, metric="euclidean", p=2.0, handle=None):
"""
Compute pairwise distances between X and Y
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import numpy as np
from libc.stdint cimport uintptr_t, int64_t
from cython.operator cimport dereference as deref
from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t
from .rng_state cimport RngState

Expand Down Expand Up @@ -73,6 +74,7 @@ cdef extern from "raft_distance/random/rmat_rectangular_generator.hpp" \
RngState& r)


@auto_sync_handle
def rmat(out, theta, r_scale, c_scale, seed=12345, handle=None):
"""
Generate RMAT adjacency list based on the input distribution.
Expand Down

0 comments on commit 5950a6d

Please sign in to comment.