Skip to content

Commit

Permalink
fix up reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller committed Nov 1, 2022
1 parent 63992b1 commit 777e257
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
6 changes: 5 additions & 1 deletion python/cudf/cudf/core/udf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
try:
import strings_udf
from strings_udf import ptxpath

if ptxpath:
utils.ptx_files.append(ptxpath)

from strings_udf._lib.cudf_jit_udf import to_string_view_array
from strings_udf._typing import str_view_arg_handler, string_view

Expand All @@ -45,7 +49,7 @@
utils.launch_arg_getters[dtype("O")] = to_string_view_array
utils.masked_array_types[dtype("O")] = string_view
utils.JIT_SUPPORTED_TYPES |= STRING_TYPES
utils.ptx_files.append(ptxpath)

utils.arg_handlers.append(str_view_arg_handler)
row_function.itemsizes[dtype("O")] = string_view.size_bytes

Expand Down
10 changes: 5 additions & 5 deletions python/strings_udf/strings_udf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
import glob
import os
import re
import subprocess
import sys

from cubinlinker.patch import _numba_version_ok, get_logger, new_patched_linker
from numba import cuda
Expand Down Expand Up @@ -44,7 +41,7 @@ def maybe_patch_numba_linker(driver_version):
logger.debug("Cannot patch Numba Linker - unsupported version")


def get_ptx_file():
def _get_ptx_file():
dev = cuda.get_current_device()

# Load the highest compute capability file available that is less than
Expand All @@ -56,6 +53,7 @@ def get_ptx_file():
"This strings_udf installation is missing the necessary PTX "
"files. Please file an issue reporting this error and how you "
"installed cudf and strings_udf."
"https://github.com/rapidsai/cudf/issues"
)

regular_sms = []
Expand Down Expand Up @@ -84,8 +82,10 @@ def get_ptx_file():
return regular_result[1]


ptxpath = None
versions = safe_get_versions()
if not versions == NO_DRIVER:
driver_version, runtime_version = versions
maybe_patch_numba_linker(driver_version)
ptxpath = get_ptx_file()
if "RAPIDS_NO_INITIALIZE" not in os.environ:
ptxpath = _get_ptx_file()

0 comments on commit 777e257

Please sign in to comment.