From 421fdfbefd01acc86022aa15344cab7f90ed38b0 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Thu, 22 Aug 2024 02:16:52 -0700 Subject: [PATCH] continue addressing reviews --- numba_cuda/numba/cuda/cudadrv/driver.py | 24 +++++++----------------- numba_cuda/numba/cuda/cudadrv/enums.py | 2 +- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 9efac41..6a16403 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -2695,7 +2695,11 @@ def add_cu_file(self, path): self.add_cu(cu, os.path.basename(path)) def add_file_guess_ext(self, path_or_code): - """Add a file to the link, guessing its type from its extension.""" + """ + Add a file or LinkableCode object to the link. If a file is + passed, the type will be inferred from the extension. A LinkableCode + object represents a file already in memory. + """ if isinstance(path_or_code, str): ext = pathlib.Path(path_or_code).suffix if ext == '': @@ -2704,6 +2708,8 @@ def add_file_guess_ext(self, path_or_code): ) elif ext == '.cu': self.add_cu_file(path_or_code) + elif ext == ".ltoir": + self.add_file(path_or_code, "ltoir") else: kind = FILE_EXTENSION_MAP.get(ext, None) if kind is None: @@ -3093,22 +3099,6 @@ def add_data(self, data, kind, name): except NvJitLinkError as e: raise LinkerError from e - def add_cu(self, cu, name): - with driver.get_active_context() as ac: - dev = driver.get_device(ac.devnum) - cc = dev.compute_capability - - ptx, log = nvrtc.compile(cu, name, cc) - - if config.DUMP_ASSEMBLY: - print(("ASSEMBLY %s" % name).center(80, "-")) - print(ptx) - print("=" * 80) - - # Link the program's PTX using the normal linker mechanism - ptx_name = os.path.splitext(name)[0] + ".ptx" - self.add_ptx(ptx.encode(), ptx_name) - def complete(self): try: cubin = self._linker.get_linked_cubin() diff --git a/numba_cuda/numba/cuda/cudadrv/enums.py b/numba_cuda/numba/cuda/cudadrv/enums.py index 917dbb1..25bbbe1 100644 --- a/numba_cuda/numba/cuda/cudadrv/enums.py +++ b/numba_cuda/numba/cuda/cudadrv/enums.py @@ -312,7 +312,7 @@ # LTO IR CU_JIT_INPUT_LTO_IR = 5 -CU_JIT_NUM_INPUT_TYPES = 7 +CU_JIT_NUM_INPUT_TYPES = 6 # Online compiler and linker options