From 9fba2b7f2651dac121626ee12f76c3f14b17aefa Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 9 Dec 2024 14:22:58 -0800 Subject: [PATCH] use precondition and update test --- cuda_core/cuda/core/experimental/_module.py | 55 +++++++++++---------- cuda_core/tests/test_module.py | 49 ++++++------------ 2 files changed, 43 insertions(+), 61 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index bd38b913..e6a5e686 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -5,7 +5,7 @@ import importlib.metadata from cuda import cuda -from cuda.core.experimental._utils import handle_return +from cuda.core.experimental._utils import handle_return, precondition _backend = { "old": { @@ -127,31 +127,10 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None): self._sym_map = {} if symbol_mapping is None else symbol_mapping # TODO: do we want to unload in a finalizer? Probably not.. - - def get_kernel(self, name): - """Return the :obj:`Kernel` of a specified name from this object code. - - Parameters - ---------- - name : Any - Name of the kernel to retrieve. - - Returns - ------- - :obj:`Kernel` - Newly created kernel object. - - """ - try: - name = self._sym_map[name] - except KeyError: - name = name.encode() - - self._lazy_load_module() - data = handle_return(self._loader["kernel"](self._handle, name)) - return Kernel._from_obj(data, self) - - def _lazy_load_module(self): + + def _lazy_load_module(self, *args, **kwargs): + if self._handle is not None: + return if isinstance(self._module, str): # TODO: this option is only taken by the new library APIs, but we have # a bug that we can't easily support it just yet (NVIDIA/cuda-python#73). @@ -178,4 +157,28 @@ def _lazy_load_module(self): args = (self._module, len(self._jit_options), list(self._jit_options.keys()), list(self._jit_options.values())) self._handle = handle_return(self._loader["data"](*args)) + @precondition(_lazy_load_module) + def get_kernel(self, name): + """Return the :obj:`Kernel` of a specified name from this object code. + + Parameters + ---------- + name : Any + Name of the kernel to retrieve. + + Returns + ------- + :obj:`Kernel` + Newly created kernel object. + + """ + try: + name = self._sym_map[name] + except KeyError: + name = name.encode() + + data = handle_return(self._loader["kernel"](self._handle, name)) + return Kernel._from_obj(data, self) + + # TODO: implement from_handle() diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index e7fec356..f952542b 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -10,38 +10,17 @@ import pytest -from cuda.core.experimental._module import ObjectCode - - -@pytest.mark.skipif( - int(importlib.metadata.version("cuda-python").split(".")[0]) < 12, - reason="Module loading for older drivers validate require valid module code.", -) -def test_object_code_initialization(): - # Test with supported code types - for code_type in ["cubin", "ptx", "fatbin"]: - module_data = b"dummy_data" - obj_code = ObjectCode(module_data, code_type) - assert obj_code._code_type == code_type - assert obj_code._module == module_data - - # Test with unsupported code type - with pytest.raises(ValueError): - ObjectCode(b"dummy_data", "unsupported_code_type") - - -# TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile -def test_object_code_initialization_with_str(): - assert True - - -def test_object_code_initialization_with_jit_options(): - assert True - - -def test_object_code_get_kernel(): - assert True - - -def test_kernel_from_obj(): - assert True +from cuda.core.experimental import Program + + +def test_get_kernel(): + kernel = """ +extern __device__ int B(); +extern __device__ int C(int a, int b); +__global__ void A() { int result = C(B(), 1);} +""" + object_code = Program(kernel, "c++").compile("ptx", options=("-rdc=true",)) + assert object_code._handle is None + kernel = object_code.get_kernel("A") + assert object_code._handle is not None + assert kernel._handle is not None