Skip to content

Commit

Permalink
Merge pull request rapidsai#5926 from rapidsai/branch-24.06
Browse files Browse the repository at this point in the history
Forward-merge branch-24.06 into branch-24.08
  • Loading branch information
GPUtester authored Jun 12, 2024
2 parents 604bf42 + 45d963a commit 2aac623
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 8 deletions.
15 changes: 14 additions & 1 deletion python/cuml/internals/device_support.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,11 +15,24 @@
#


from packaging.version import Version

MIN_SKLEARN_VERSION = Version('1.5')


try:
import sklearn # noqa: F401 # no-cython-lint

CPU_ENABLED = True

if(Version(sklearn.__version__) >= MIN_SKLEARN_VERSION):
MIN_SKLEARN_PRESENT = (True, None, None)
else:
MIN_SKLEARN_PRESENT = (False, sklearn.__version__, MIN_SKLEARN_VERSION)

except ImportError:
CPU_ENABLED = False
MIN_SKLEARN_PRESENT = (False, None, None)

IF GPUBUILD == 1:
GPU_ENABLED = True
Expand Down
48 changes: 41 additions & 7 deletions python/cuml/internals/safe_imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,9 +17,15 @@

import importlib
import traceback
import warnings

from contextlib import contextmanager
from cuml.internals.device_support import CPU_ENABLED, GPU_ENABLED
from cuml.internals.device_support import (
CPU_ENABLED,
GPU_ENABLED,
MIN_SKLEARN_VERSION,
MIN_SKLEARN_PRESENT,
)
from cuml.internals import logger


Expand Down Expand Up @@ -429,12 +435,25 @@ def cpu_only_import(module, *, alt=None):
The imported module, the given alternate, or a class derived from
UnavailableMeta.
"""
if CPU_ENABLED:
if CPU_ENABLED and MIN_SKLEARN_PRESENT[0]:
return importlib.import_module(module)

else:
if CPU_ENABLED:
err_msg = (
"Installed version of Scikit-learn {} "
"is lower than the latest tested and supported "
"version {}. This can affect the functionality "
"of some CPU components of cuML, GPU estimators "
"are unaffected.".format(
MIN_SKLEARN_PRESENT[1], MIN_SKLEARN_PRESENT[2]
)
)
else:
err_msg = f"{module} is not installed in GPU-only installations"
return safe_import(
module,
msg=f"{module} is not installed in GPU-only installations",
msg=err_msg,
alt=alt,
)

Expand Down Expand Up @@ -467,14 +486,29 @@ def cpu_only_import_from(module, symbol, *, alt=None):
The imported symbol, the given alternate, or a class derived from
UnavailableMeta.
"""
if CPU_ENABLED:
if CPU_ENABLED and MIN_SKLEARN_PRESENT[0]:
imported_module = importlib.import_module(module)
return getattr(imported_module, symbol)
else:
if CPU_ENABLED:
err_msg = (
"Installed version of Scikit-learn {} "
"is lower than the latest tested and supported "
"version {}. This can affect the functionality "
"of some CPU components of cuML, GPU estimators "
"are unaffected.".format(
MIN_SKLEARN_PRESENT[1], MIN_SKLEARN_PRESENT[2]
)
)
else:
err_msg = (
f"{module}.{symbol} is not installed in GPU-only "
"installations"
)

return safe_import_from(
module,
symbol,
msg=f"{module}.{symbol} is not available in GPU-only"
" installations",
msg=err_msg,
alt=alt,
)

0 comments on commit 2aac623

Please sign in to comment.