diff --git a/python/cudf/cudf/core/udf/__init__.py b/python/cudf/cudf/core/udf/__init__.py index 3df1e0bd1d4..18bfc70decb 100644 --- a/python/cudf/cudf/core/udf/__init__.py +++ b/python/cudf/cudf/core/udf/__init__.py @@ -26,6 +26,10 @@ try: import strings_udf from strings_udf import ptxpath + + if ptxpath: + utils.ptx_files.append(ptxpath) + from strings_udf._lib.cudf_jit_udf import to_string_view_array from strings_udf._typing import str_view_arg_handler, string_view @@ -45,7 +49,7 @@ utils.launch_arg_getters[dtype("O")] = to_string_view_array utils.masked_array_types[dtype("O")] = string_view utils.JIT_SUPPORTED_TYPES |= STRING_TYPES - utils.ptx_files.append(ptxpath) + utils.arg_handlers.append(str_view_arg_handler) row_function.itemsizes[dtype("O")] = string_view.size_bytes diff --git a/python/strings_udf/strings_udf/__init__.py b/python/strings_udf/strings_udf/__init__.py index 970d40451d9..ee5d19c1ec0 100644 --- a/python/strings_udf/strings_udf/__init__.py +++ b/python/strings_udf/strings_udf/__init__.py @@ -1,9 +1,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. import glob import os -import re -import subprocess -import sys from cubinlinker.patch import _numba_version_ok, get_logger, new_patched_linker from numba import cuda @@ -44,7 +41,7 @@ def maybe_patch_numba_linker(driver_version): logger.debug("Cannot patch Numba Linker - unsupported version") -def get_ptx_file(): +def _get_ptx_file(): dev = cuda.get_current_device() # Load the highest compute capability file available that is less than @@ -56,6 +53,7 @@ def get_ptx_file(): "This strings_udf installation is missing the necessary PTX " "files. Please file an issue reporting this error and how you " "installed cudf and strings_udf." + "https://github.com/rapidsai/cudf/issues" ) regular_sms = [] @@ -84,8 +82,10 @@ def get_ptx_file(): return regular_result[1] +ptxpath = None versions = safe_get_versions() if not versions == NO_DRIVER: driver_version, runtime_version = versions maybe_patch_numba_linker(driver_version) - ptxpath = get_ptx_file() + if "RAPIDS_NO_INITIALIZE" not in os.environ: + ptxpath = _get_ptx_file()