diff --git a/python/cudf/cudf/utils/_ptxcompiler.py b/python/cudf/cudf/utils/_ptxcompiler.py index 54f5ea08ee1..9d7071d55a5 100644 --- a/python/cudf/cudf/utils/_ptxcompiler.py +++ b/python/cudf/cudf/utils/_ptxcompiler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,14 @@ import math import os +import re import subprocess import sys import warnings NO_DRIVER = (math.inf, math.inf) +START_TAG = "_VER_START" +END_TAG = "_VER_END" NUMBA_CHECK_VERSION_CMD = """\ from ctypes import c_int, byref @@ -28,7 +31,7 @@ drv_major = dv.value // 1000 drv_minor = (dv.value - (drv_major * 1000)) // 10 run_major, run_minor = cuda.runtime.get_version() -print(f'{drv_major} {drv_minor} {run_major} {run_minor}') +print(f'_VER_START{drv_major} {drv_minor} {run_major} {run_minor}_VER_END') """ @@ -61,7 +64,11 @@ def get_versions(): warnings.warn(msg, UserWarning) return NO_DRIVER - versions = [int(s) for s in cp.stdout.strip().split()] + pattern = r"_VER_START(.*?)_VER_END" + + ver_str = re.search(pattern, cp.stdout.decode()).group(1) + + versions = [int(s) for s in ver_str.strip().split()] driver_version = tuple(versions[:2]) runtime_version = tuple(versions[2:])