diff --git a/python/cudf/cudf/utils/_numba.py b/python/cudf/cudf/utils/_numba.py index fb96ad81941..10041253dac 100644 --- a/python/cudf/cudf/utils/_numba.py +++ b/python/cudf/cudf/utils/_numba.py @@ -6,9 +6,6 @@ from functools import lru_cache from numba import config as numba_config -from pynvjitlink.patch import ( - patch_numba_linker as patch_numba_linker_pynvjitlink, -) # Use an lru_cache with a single value to allow a delayed import of @@ -135,7 +132,9 @@ def _setup_numba(): if driver_version < (12, 0): patch_numba_linker_cuda_11() else: - patch_numba_linker_pynvjitlink() + from pynvjitlink.patch import patch_numba_linker + + patch_numba_linker() def _get_cuda_version_from_ptx_file(path):