Skip to content

Commit

Permalink
comments and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Dec 16, 2024
1 parent 5273e4a commit bb1cf0f
Showing 1 changed file with 75 additions and 13 deletions.
88 changes: 75 additions & 13 deletions numba_cuda/numba/cuda/runtime/nrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
from numba.cuda.utils import _readenv


# Resolve whether NRT statistics are enabled: the NUMBA_CUDA_NRT_STATS
# environment variable is consulted first, then a same-named attribute on
# the config object.
NRT_STATS = _readenv("NUMBA_CUDA_NRT_STATS", bool, False) or getattr(
    config, "NUMBA_CUDA_NRT_STATS", False
)
# Publish the resolved value on config as CUDA_NRT_STATS unless config
# already carries a NUMBA_CUDA_NRT_STATS attribute.
if not hasattr(config, "NUMBA_CUDA_NRT_STATS"):
    setattr(config, "CUDA_NRT_STATS", NRT_STATS)


# Check environment variable or config for NRT enablement
ENABLE_NRT = (
_readenv("NUMBA_CUDA_ENABLE_NRT", bool, False) or
getattr(config, "NUMBA_CUDA_ENABLE_NRT", False)
Expand All @@ -25,7 +28,11 @@
config.CUDA_ENABLE_NRT = ENABLE_NRT


# Decorator that guards a method so NRT memory is allocated and initialized first
def _alloc_init_guard(method):
"""
Ensure NRT memory allocation and initialization before running the method
"""
@wraps(method)
def wrapper(self, *args, **kwargs):
self.ensure_allocated()
Expand All @@ -35,6 +42,7 @@ def wrapper(self, *args, **kwargs):


class _Runtime:
"""Singleton class for Numba CUDA runtime"""
_instance = None

def __new__(cls, *args, **kwargs):
Expand All @@ -43,47 +51,64 @@ def __new__(cls, *args, **kwargs):
return cls._instance

def __init__(self):
"""Initialize memsys module and variable"""
self._memsys_module = None
self._memsys = None

self._initialized = False

def _compile_memsys_module(self):
    """
    Compile memsys.cu and create a module from it in the current context.

    The source file is looked up in the directory containing this Python
    file and linked for the compute capability of the current device. The
    resulting module is stored on ``self._memsys_module``.
    """
    # memsys.cu sits next to this file
    memsys_mod = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "memsys.cu"
    )
    cc = get_current_device().compute_capability

    # Create a new linker instance and add the cu file
    linker = Linker.new(cc=cc)
    linker.add_cu_file(memsys_mod)

    # Complete the link exactly once, then create a module from the result
    cubin = linker.complete()
    ctx = devices.get_context()
    module = ctx.create_module_image(cubin)

    # Set the memsys module
    self._memsys_module = module

def ensure_allocated(self, stream=None):
    """
    Allocate memsys via `allocate(stream)` unless it already exists.

    Does nothing when `self._memsys` is already set.
    """
    if self._memsys is None:
        self.allocate(stream)

def allocate(self, stream=None):
    """
    Allocate memsys as a device array and hand it to the memsys module,
    compiling that module first when it does not exist yet.
    """
    from numba.cuda import device_array

    # Build the memsys module on first use
    if self._memsys_module is None:
        self._compile_memsys_module()

    # Reserve 40 one-byte elements for NRT_MemSys
    # TODO: determine the size of NRT_MemSys at runtime
    memsys_shape = (40,)
    self._memsys = device_array(memsys_shape, dtype="i1", stream=stream)

    # TODO: Memsys module needs a stream that's consistent with the
    # system's stream.
    self.set_memsys_to_module(self._memsys_module, stream=stream)

def _single_thread_launch(self, module, stream, name, params=()):
"""
Launch the specified kernel with only 1 thread
"""
if stream is None:
stream = cuda.default_stream()

Expand All @@ -99,33 +124,45 @@ def _single_thread_launch(self, module, stream, name, params=()):
)

def ensure_initialized(self, stream=None):
    """
    Initialize memsys via `initialize(stream)` unless that has already
    been done.
    """
    if not self._initialized:
        self.initialize(stream)

def initialize(self, stream=None):
    """
    Launch the memsys initialization kernel.

    Makes sure memsys is allocated, launches `NRT_MemSys_init`, marks the
    runtime as initialized, and enables statistics when `NRT_STATS` is
    set.
    """
    self.ensure_allocated(stream)

    self._single_thread_launch(
        self._memsys_module, stream, "NRT_MemSys_init")
    # Set the flag before enabling stats: memsys_enable_stats is guarded
    # and would otherwise attempt initialization again.
    self._initialized = True

    if NRT_STATS:
        self.memsys_enable_stats(stream)

@_alloc_init_guard
def memsys_enable_stats(self, stream=None):
    """
    Turn memsys statistics on by launching the
    `NRT_MemSys_enable_stats` kernel, passing `stream` to the launcher.
    """
    kernel_name = "NRT_MemSys_enable_stats"
    self._single_thread_launch(self._memsys_module, stream, kernel_name)

@_alloc_init_guard
def memsys_disable_stats(self, stream=None):
    """
    Turn memsys statistics off by launching the
    `NRT_MemSys_disable_stats` kernel, passing `stream` to the launcher.
    """
    kernel_name = "NRT_MemSys_disable_stats"
    self._single_thread_launch(self._memsys_module, stream, kernel_name)

@_alloc_init_guard
def memsys_stats_enabled(self, stream=None):
"""
Return a boolean indicating whether memsys is enabled. Synchronizes
context
"""
enabled_ar = cuda.managed_array(1, np.uint8)

self._single_thread_launch(
Expand All @@ -140,8 +177,9 @@ def memsys_stats_enabled(self, stream=None):

@_alloc_init_guard
def _copy_memsys_to_host(self, stream):

# Q: What stream should we execute this on?
"""
Copy all statistics of memsys to the host
"""
dt = np.dtype([
('alloc', np.uint64),
('free', np.uint64),
Expand All @@ -163,6 +201,9 @@ def _copy_memsys_to_host(self, stream):

@_alloc_init_guard
def get_allocation_stats(self, stream=None):
"""
Get the allocation statistics
"""
enabled = self.memsys_stats_enabled(stream)
if not enabled:
raise RuntimeError("NRT stats are disabled.")
Expand All @@ -176,6 +217,9 @@ def get_allocation_stats(self, stream=None):

@_alloc_init_guard
def _get_single_stat(self, stat, stream=None):
"""
Get a single stat from the memsys
"""
got = cuda.managed_array(1, np.uint64)
self._single_thread_launch(
self._memsys_module,
Expand All @@ -189,6 +233,9 @@ def _get_single_stat(self, stat, stream=None):

@_alloc_init_guard
def memsys_get_stats_alloc(self, stream=None):
"""
Get the allocation statistic
"""
enabled = self.memsys_stats_enabled(stream)
if not enabled:
raise RuntimeError("NRT stats are disabled.")
Expand All @@ -197,6 +244,9 @@ def memsys_get_stats_alloc(self, stream=None):

@_alloc_init_guard
def memsys_get_stats_free(self, stream=None):
"""
Get the free statistic
"""
enabled = self.memsys_stats_enabled(stream)
if not enabled:
raise RuntimeError("NRT stats are disabled.")
Expand All @@ -205,6 +255,9 @@ def memsys_get_stats_free(self, stream=None):

@_alloc_init_guard
def memsys_get_stats_mi_alloc(self, stream=None):
"""
Get the mi alloc statistic
"""
enabled = self.memsys_stats_enabled(stream)
if not enabled:
raise RuntimeError("NRT stats are disabled.")
Expand All @@ -213,18 +266,24 @@ def memsys_get_stats_mi_alloc(self, stream=None):

@_alloc_init_guard
def memsys_get_stats_mi_free(self, stream=None):
    """
    Get the mi free statistic.

    Raises RuntimeError when memsys statistics are not enabled. The
    `stream` argument is used both for the enabled check and for reading
    the statistic.
    """
    if not self.memsys_stats_enabled(stream):
        raise RuntimeError("NRT stats are disabled.")

    # Forward the caller's stream so the read is launched on the same
    # stream as the enabled check instead of silently dropping it.
    return self._get_single_stat("mi_free", stream)

def set_memsys_to_module(self, module, stream=None):
"""
Set the memsys module. The module must contain `NRT_MemSys_set` kernel,
and declare a pointer to NRT_MemSys structure.
"""
if self._memsys is None:
raise RuntimeError(
"Please allocate NRT Memsys first before initializing.")

print(f"Setting {self._memsys.device_ctypes_pointer} to {module}")
self._single_thread_launch(
module,
stream,
Expand All @@ -234,7 +293,9 @@ def set_memsys_to_module(self, module, stream=None):

@_alloc_init_guard
def print_memsys(self, stream=None):
"""Print the current statistics of memsys, for debugging purpose."""
"""
Print the current statistics of memsys, for debugging purposes
"""
cuda.synchronize()
self._single_thread_launch(
self._memsys_module,
Expand All @@ -243,4 +304,5 @@ def print_memsys(self, stream=None):
)


# Module-level runtime object. `_Runtime` is documented as a singleton, so
# this is intended to be the one shared instance used by importers.
rtsys = _Runtime()

0 comments on commit bb1cf0f

Please sign in to comment.