diff --git a/conda/environments/raft_dev_cuda11.5.yml b/conda/environments/raft_dev_cuda11.5.yml index 152f3a8db5..c6d9f3fbf5 100644 --- a/conda/environments/raft_dev_cuda11.5.yml +++ b/conda/environments/raft_dev_cuda11.5.yml @@ -6,6 +6,7 @@ channels: - conda-forge dependencies: - cudatoolkit=11.5 +- cuda-python >=11.5,<12.0 - clang=11.1.0 - clang-tools=11.1.0 - rapids-build-env=22.02.* diff --git a/python/raft/common/cuda.pxd b/python/raft/common/cuda.pxd index e407213f44..0459cb96af 100644 --- a/python/raft/common/cuda.pxd +++ b/python/raft/common/cuda.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2019, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,23 +14,9 @@ # limitations under the License. # -# cython: profile=False -# distutils: language = c++ -# cython: embedsignature = True -# cython: language_level = 3 +from cuda.ccudart cimport cudaStream_t +cdef class Stream: + cdef cudaStream_t s -# Populate this with more typedef's (eg: events) as and when needed -cdef extern from * nogil: - ctypedef void* _Stream "cudaStream_t" - ctypedef int _Error "cudaError_t" - - -# Populate this with more runtime api method declarations as and when needed -cdef extern from "cuda_runtime_api.h" nogil: - _Error cudaStreamCreate(_Stream* s) - _Error cudaStreamDestroy(_Stream s) - _Error cudaStreamSynchronize(_Stream s) - _Error cudaGetLastError() - const char* cudaGetErrorString(_Error e) - const char* cudaGetErrorName(_Error e) + cdef cudaStream_t getStream(self) diff --git a/python/raft/common/cuda.pyx b/python/raft/common/cuda.pyx index 0b97eeba67..c3c90936aa 100644 --- a/python/raft/common/cuda.pyx +++ b/python/raft/common/cuda.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,10 +19,22 @@ # cython: embedsignature = True # cython: language_level = 3 +from cuda.ccudart cimport( + cudaStream_t, + cudaError_t, + cudaSuccess, + cudaStreamCreate, + cudaStreamDestroy, + cudaStreamSynchronize, + cudaGetLastError, + cudaGetErrorString, + cudaGetErrorName +) + class CudaRuntimeError(RuntimeError): def __init__(self, extraMsg=None): - cdef _Error e = cudaGetLastError() + cdef cudaError_t e = cudaGetLastError() cdef bytes errMsg = cudaGetErrorString(e) cdef bytes errName = cudaGetErrorName(e) msg = "Error! %s reason='%s'" % (errName.decode(), errMsg.decode()) @@ -45,29 +57,17 @@ cdef class Stream: stream.sync() del stream # optional! """ - - # NOTE: - # If we store _Stream directly, this always leads to the following error: - # "Cannot convert Python object to '_Stream'" - # I was unable to find a good solution to this in reasonable time. Also, - # since cudaStream_t is a pointer anyways, storing it as an integer should - # be just fine (although, that certainly is ugly and hacky!). - cdef size_t s - def __cinit__(self): - if self.s != 0: - return - cdef _Stream stream - cdef _Error e = cudaStreamCreate(&stream) - if e != 0: + cdef cudaStream_t stream + cdef cudaError_t e = cudaStreamCreate(&stream) + if e != cudaSuccess: raise CudaRuntimeError("Stream create") - self.s = stream + self.s = stream def __dealloc__(self): self.sync() - cdef _Stream stream = <_Stream>self.s - cdef _Error e = cudaStreamDestroy(stream) - if e != 0: + cdef cudaError_t e = cudaStreamDestroy(self.s) + if e != cudaSuccess: raise CudaRuntimeError("Stream destroy") def sync(self): @@ -76,10 +76,9 @@ cdef class Stream: could raise exception due to issues with previous asynchronous launches """ - cdef _Stream stream = <_Stream>self.s - cdef _Error e = cudaStreamSynchronize(stream) - if e != 0: + cdef cudaError_t e = cudaStreamSynchronize(self.s) + if e != cudaSuccess: raise CudaRuntimeError("Stream sync") - def getStream(self): + cdef cudaStream_t getStream(self): return self.s diff --git a/python/raft/common/handle.pxd b/python/raft/common/handle.pxd index d2ae0a401d..8415b7e3d7 100644 --- a/python/raft/common/handle.pxd +++ b/python/raft/common/handle.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ from libcpp.memory cimport shared_ptr -from .cuda cimport _Stream from rmm._lib.cuda_stream_view cimport cuda_stream_view from rmm._lib.cuda_stream_pool cimport cuda_stream_pool from libcpp.memory cimport shared_ptr diff --git a/python/raft/common/handle.pyx b/python/raft/common/handle.pyx index 1accf9e679..661c5b5f23 100644 --- a/python/raft/common/handle.pyx +++ b/python/raft/common/handle.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,9 +24,10 @@ from libcpp.memory cimport shared_ptr from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread from rmm._lib.cuda_stream_view cimport cuda_stream_view -from .cuda cimport _Stream, _Error, cudaStreamSynchronize +from .cuda cimport Stream from .cuda import CudaRuntimeError + cdef class Handle: """ Handle is a lightweight python wrapper around the corresponding C++ class @@ -51,7 +52,7 @@ cdef class Handle: del handle # optional! """ - def __cinit__(self, stream=None, n_streams=0): + def __cinit__(self, stream: Stream = None, n_streams=0): self.n_streams = n_streams if n_streams > 0: self.stream_pool.reset(new cuda_stream_pool(n_streams)) @@ -64,7 +65,7 @@ cdef class Handle: self.stream_pool)) else: # this constructor constructs a handle on user stream - c_stream = cuda_stream_view(<_Stream> stream.getStream()) + c_stream = cuda_stream_view(stream.getStream()) self.c_obj.reset(new handle_t(c_stream, self.stream_pool))