Skip to content

Commit

Permalink
Explicitly control the use of streams in tests with NRT libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Dec 5, 2024
1 parent 06f5e53 commit f4d1a80
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
9 changes: 6 additions & 3 deletions numba_cuda/numba/cuda/runtime/nrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@ def allocate(self, stream):
self.set_memsys_to_module(self._memsys_module, stream=stream)

def _single_thread_launch(self, module, stream, name, params=()):
    """Run the kernel called *name* from *module* with a single thread.

    A 1x1x1 grid of 1x1x1 blocks is used, so exactly one thread executes
    the kernel.  When *stream* is None, the launch is enqueued on the
    default CUDA stream.

    Parameters
    ----------
    module : loaded CUDA module providing ``get_function``
    stream : CUDA stream or None
        Stream to launch on; falls back to ``cuda.default_stream()``.
    name : str
        Name of the kernel function inside *module*.
    params : tuple, optional
        Kernel launch parameters.
    """
    if stream is None:
        stream = cuda.default_stream()

    kernel = module.get_function(name)

    # Single-thread launch: grid (1, 1, 1), block (1, 1, 1), no
    # dynamic shared memory, non-cooperative.
    launch_kernel(
        kernel.handle,
        1, 1, 1,
        1, 1, 1,
        0,
        stream.handle,
        params,
        cooperative=False
    )
Expand Down Expand Up @@ -92,7 +95,7 @@ def memsys_stats_disabled(self, stream):
self._single_thread_launch(
self._memsys_module, stream, "NRT_MemSys_disable")

def _copy_memsys_to_host(self, stream=0):
def _copy_memsys_to_host(self, stream):
self.ensure_allocate(stream)
self.ensure_initialize(stream)

Expand All @@ -116,7 +119,7 @@ def _copy_memsys_to_host(self, stream=0):

return stats_for_read[0]

def get_allocation_stats(self, stream=0):
def get_allocation_stats(self, stream):
memsys = self._copy_memsys_to_host(stream)
return _nrt_mstats(
alloc=memsys["alloc"],
Expand Down
33 changes: 18 additions & 15 deletions numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from numba.cuda.runtime import rtsys
from numba.tests.support import EnableNRTStatsMixin
from numba.cuda.testing import CUDATestCase

from numba.cuda.tests.nrt.mock_numpy import cuda_empty, cuda_empty_like

from numba import cuda
Expand All @@ -25,17 +24,17 @@ def test_no_return(self):
"""
n = 10

@cuda.jit(debug=True)
@cuda.jit
def kernel():
for i in range(n):
temp = cuda_empty(2, np.float64) # noqa: F841
return None

init_stats = rtsys.get_allocation_stats()

stream = cuda.default_stream()
init_stats = rtsys.get_allocation_stats(stream)
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
kernel[1,1]()
cur_stats = rtsys.get_allocation_stats()
kernel[1, 1, stream]()
cur_stats = rtsys.get_allocation_stats(stream)
self.assertEqual(cur_stats.alloc - init_stats.alloc, n)
self.assertEqual(cur_stats.free - init_stats.free, n)

Expand All @@ -57,10 +56,11 @@ def g(n):

return None

init_stats = rtsys.get_allocation_stats()
stream = cuda.default_stream()
init_stats = rtsys.get_allocation_stats(stream)
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
g[1, 1](10)
cur_stats = rtsys.get_allocation_stats()
g[1, 1, stream](10)
cur_stats = rtsys.get_allocation_stats(stream)
self.assertEqual(cur_stats.alloc - init_stats.alloc, 1)
self.assertEqual(cur_stats.free - init_stats.free, 1)

Expand All @@ -80,10 +80,11 @@ def if_with_allocation_and_initialization(arr1, test1):

arr = np.random.random((5, 5)) # the values are not consumed

init_stats = rtsys.get_allocation_stats()
stream = cuda.default_stream()
init_stats = rtsys.get_allocation_stats(stream)
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
if_with_allocation_and_initialization[1, 1](arr, False)
cur_stats = rtsys.get_allocation_stats()
if_with_allocation_and_initialization[1, 1, stream](arr, False)
cur_stats = rtsys.get_allocation_stats(stream)
self.assertEqual(cur_stats.alloc - init_stats.alloc,
cur_stats.free - init_stats.free)

Expand All @@ -103,10 +104,12 @@ def f(arr):
res += t[i]

arr = np.ones((2, 2))
init_stats = rtsys.get_allocation_stats()

stream = cuda.default_stream()
init_stats = rtsys.get_allocation_stats(stream)
with patch('numba.config.CUDA_ENABLE_NRT', True, create=True):
f[1, 1](arr)
cur_stats = rtsys.get_allocation_stats()
f[1, 1, stream](arr)
cur_stats = rtsys.get_allocation_stats(stream)
self.assertEqual(cur_stats.alloc - init_stats.alloc,
cur_stats.free - init_stats.free)

Expand Down

0 comments on commit f4d1a80

Please sign in to comment.