-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding handle and stream to pylibraft (#683)
Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: #683
- Loading branch information
Showing
3 changed files
with
174 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# cython: profile=False | ||
# distutils: language = c++ | ||
# cython: embedsignature = True | ||
# cython: language_level = 3 | ||
|
||
from cuda.ccudart cimport( | ||
cudaStream_t, | ||
cudaError_t, | ||
cudaSuccess, | ||
cudaStreamCreate, | ||
cudaStreamDestroy, | ||
cudaStreamSynchronize, | ||
cudaGetLastError, | ||
cudaGetErrorString, | ||
cudaGetErrorName | ||
) | ||
|
||
|
||
class CudaRuntimeError(RuntimeError): | ||
def __init__(self, extraMsg=None): | ||
cdef cudaError_t e = cudaGetLastError() | ||
cdef bytes errMsg = cudaGetErrorString(e) | ||
cdef bytes errName = cudaGetErrorName(e) | ||
msg = "Error! %s reason='%s'" % (errName.decode(), errMsg.decode()) | ||
if extraMsg is not None: | ||
msg += " extraMsg='%s'" % extraMsg | ||
super(CudaRuntimeError, self).__init__(msg) | ||
|
||
|
||
cdef class Stream: | ||
""" | ||
Stream represents a thin-wrapper around cudaStream_t and its operations. | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
from raft.common.cuda import Stream | ||
stream = Stream() | ||
stream.sync() | ||
del stream # optional! | ||
""" | ||
def __cinit__(self): | ||
cdef cudaStream_t stream | ||
cdef cudaError_t e = cudaStreamCreate(&stream) | ||
if e != cudaSuccess: | ||
raise CudaRuntimeError("Stream create") | ||
self.s = stream | ||
|
||
def __dealloc__(self): | ||
self.sync() | ||
cdef cudaError_t e = cudaStreamDestroy(self.s) | ||
if e != cudaSuccess: | ||
raise CudaRuntimeError("Stream destroy") | ||
|
||
def sync(self): | ||
""" | ||
Synchronize on the cudastream owned by this object. Note that this | ||
could raise exception due to issues with previous asynchronous | ||
launches | ||
""" | ||
cdef cudaError_t e = cudaStreamSynchronize(self.s) | ||
if e != cudaSuccess: | ||
raise CudaRuntimeError("Stream sync") | ||
|
||
cdef cudaStream_t getStream(self): | ||
return self.s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# cython: profile=False | ||
# distutils: language = c++ | ||
# cython: embedsignature = True | ||
# cython: language_level = 3 | ||
|
||
# import raft | ||
from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread | ||
from rmm._lib.cuda_stream_view cimport cuda_stream_view | ||
|
||
from .cuda cimport Stream | ||
from .cuda import CudaRuntimeError | ||
|
||
|
||
cdef class Handle: | ||
""" | ||
Handle is a lightweight python wrapper around the corresponding C++ class | ||
of handle_t exposed by RAFT's C++ interface. Refer to the header file | ||
raft/handle.hpp for interface level details of this struct | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
from raft.common import Stream, Handle | ||
stream = Stream() | ||
handle = Handle(stream) | ||
# call algos here | ||
# final sync of all work launched in the stream of this handle | ||
# this is same as `raft.cuda.Stream.sync()` call, but safer in case | ||
# the default stream inside the `handle_t` is being used | ||
handle.sync() | ||
del handle # optional! | ||
""" | ||
|
||
def __cinit__(self, stream: Stream = None, n_streams=0): | ||
self.n_streams = n_streams | ||
if n_streams > 0: | ||
self.stream_pool.reset(new cuda_stream_pool(n_streams)) | ||
|
||
cdef cuda_stream_view c_stream | ||
if stream is None: | ||
# this constructor will construct a "main" handle on | ||
# per-thread default stream, which is non-blocking | ||
self.c_obj.reset(new handle_t(cuda_stream_per_thread, | ||
self.stream_pool)) | ||
else: | ||
# this constructor constructs a handle on user stream | ||
c_stream = cuda_stream_view(stream.getStream()) | ||
self.c_obj.reset(new handle_t(c_stream, | ||
self.stream_pool)) | ||
|
||
def sync(self): | ||
""" | ||
Issues a sync on the stream set for this handle. | ||
""" | ||
self.c_obj.get()[0].sync_stream() | ||
|
||
def getHandle(self): | ||
return <size_t> self.c_obj.get() | ||
|
||
def __getstate__(self): | ||
return self.n_streams | ||
|
||
def __setstate__(self, state): | ||
self.n_streams = state | ||
if self.n_streams > 0: | ||
self.stream_pool.reset(new cuda_stream_pool(self.n_streams)) | ||
|
||
self.c_obj.reset(new handle_t(cuda_stream_per_thread, | ||
self.stream_pool)) |