Skip to content

Commit

Permalink
More safely parse CUDA versions when subprocess output is contaminated (
Browse files Browse the repository at this point in the history
#16067)

In some user environments, calling a subprocess may produce output that confuses the version parsing machinery inside `_ptxcompiler`. Since the affected functions are vendored from the real `ptxcompiler` package for the purposes of using them with CUDA 12, this fix will only these situations for CUDA 12+.

Closes #16016.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #16067
  • Loading branch information
brandon-b-miller authored Jun 24, 2024
1 parent 9987410 commit f583879
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions python/cudf/cudf/utils/_ptxcompiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,11 +14,14 @@

import math
import os
import re
import subprocess
import sys
import warnings

NO_DRIVER = (math.inf, math.inf)
START_TAG = "_VER_START"
END_TAG = "_VER_END"

NUMBA_CHECK_VERSION_CMD = """\
from ctypes import c_int, byref
Expand All @@ -28,7 +31,7 @@
drv_major = dv.value // 1000
drv_minor = (dv.value - (drv_major * 1000)) // 10
run_major, run_minor = cuda.runtime.get_version()
print(f'{drv_major} {drv_minor} {run_major} {run_minor}')
print(f'_VER_START{drv_major} {drv_minor} {run_major} {run_minor}_VER_END')
"""


Expand Down Expand Up @@ -61,7 +64,11 @@ def get_versions():
warnings.warn(msg, UserWarning)
return NO_DRIVER

versions = [int(s) for s in cp.stdout.strip().split()]
pattern = r"_VER_START(.*?)_VER_END"

ver_str = re.search(pattern, cp.stdout.decode()).group(1)

versions = [int(s) for s in ver_str.strip().split()]
driver_version = tuple(versions[:2])
runtime_version = tuple(versions[2:])

Expand Down

0 comments on commit f583879

Please sign in to comment.