diff --git a/python/pylibraft/pylibraft/common/cuda.pyx b/python/pylibraft/pylibraft/common/cuda.pyx new file mode 100644 index 0000000000..eb48f64cf1 --- /dev/null +++ b/python/pylibraft/pylibraft/common/cuda.pyx @@ -0,0 +1,84 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cuda.ccudart cimport( + cudaStream_t, + cudaError_t, + cudaSuccess, + cudaStreamCreate, + cudaStreamDestroy, + cudaStreamSynchronize, + cudaGetLastError, + cudaGetErrorString, + cudaGetErrorName +) + + +class CudaRuntimeError(RuntimeError): + def __init__(self, extraMsg=None): + cdef cudaError_t e = cudaGetLastError() + cdef bytes errMsg = cudaGetErrorString(e) + cdef bytes errName = cudaGetErrorName(e) + msg = "Error! %s reason='%s'" % (errName.decode(), errMsg.decode()) + if extraMsg is not None: + msg += " extraMsg='%s'" % extraMsg + super(CudaRuntimeError, self).__init__(msg) + + +cdef class Stream: + """ + Stream represents a thin-wrapper around cudaStream_t and its operations. + + Examples + -------- + + .. code-block:: python + + from raft.common.cuda import Stream + stream = Stream() + stream.sync() + del stream # optional! + """ + def __cinit__(self): + cdef cudaStream_t stream + cdef cudaError_t e = cudaStreamCreate(&stream) + if e != cudaSuccess: + raise CudaRuntimeError("Stream create") + self.s = stream + + def __dealloc__(self): + self.sync() + cdef cudaError_t e = cudaStreamDestroy(self.s) + if e != cudaSuccess: + raise CudaRuntimeError("Stream destroy") + + def sync(self): + """ + Synchronize on the cudastream owned by this object. Note that this + could raise exception due to issues with previous asynchronous + launches + """ + cdef cudaError_t e = cudaStreamSynchronize(self.s) + if e != cudaSuccess: + raise CudaRuntimeError("Stream sync") + + cdef cudaStream_t getStream(self): + return self.s diff --git a/python/pylibraft/pylibraft/common/handle.pxd b/python/pylibraft/pylibraft/common/handle.pxd index bc248a335b..6504a122f7 100644 --- a/python/pylibraft/pylibraft/common/handle.pxd +++ b/python/pylibraft/pylibraft/common/handle.pxd @@ -25,7 +25,7 @@ from rmm._lib.cuda_stream_pool cimport cuda_stream_pool from libcpp.memory cimport shared_ptr from libcpp.memory cimport unique_ptr -cdef extern from "raft/handle.hpp" namespace "raft" nogil: +cdef extern from "raft/core/handle.hpp" namespace "raft" nogil: cdef cppclass handle_t: handle_t() except + handle_t(cuda_stream_view stream_view) except + diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx new file mode 100644 index 0000000000..83a4676076 --- /dev/null +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -0,0 +1,89 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +# import raft +from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread +from rmm._lib.cuda_stream_view cimport cuda_stream_view + +from .cuda cimport Stream +from .cuda import CudaRuntimeError + + +cdef class Handle: + """ + Handle is a lightweight python wrapper around the corresponding C++ class + of handle_t exposed by RAFT's C++ interface. Refer to the header file + raft/handle.hpp for interface level details of this struct + + Examples + -------- + + .. code-block:: python + + from raft.common import Stream, Handle + stream = Stream() + handle = Handle(stream) + + # call algos here + + # final sync of all work launched in the stream of this handle + # this is same as `raft.cuda.Stream.sync()` call, but safer in case + # the default stream inside the `handle_t` is being used + handle.sync() + del handle # optional! + """ + + def __cinit__(self, stream: Stream = None, n_streams=0): + self.n_streams = n_streams + if n_streams > 0: + self.stream_pool.reset(new cuda_stream_pool(n_streams)) + + cdef cuda_stream_view c_stream + if stream is None: + # this constructor will construct a "main" handle on + # per-thread default stream, which is non-blocking + self.c_obj.reset(new handle_t(cuda_stream_per_thread, + self.stream_pool)) + else: + # this constructor constructs a handle on user stream + c_stream = cuda_stream_view(stream.getStream()) + self.c_obj.reset(new handle_t(c_stream, + self.stream_pool)) + + def sync(self): + """ + Issues a sync on the stream set for this handle. + """ + self.c_obj.get()[0].sync_stream() + + def getHandle(self): + return self.c_obj.get() + + def __getstate__(self): + return self.n_streams + + def __setstate__(self, state): + self.n_streams = state + if self.n_streams > 0: + self.stream_pool.reset(new cuda_stream_pool(self.n_streams)) + + self.c_obj.reset(new handle_t(cuda_stream_per_thread, + self.stream_pool))