Skip to content

Commit

Permalink
Fix for creation of CUDA context at import time (#5211)
Browse files Browse the repository at this point in the history
closes issue #5206

Authors:
   - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
   - Corey J. Nolet (https://github.com/cjnolet)
  • Loading branch information
dantegd authored Feb 7, 2023
1 parent 8abfd64 commit 277f2da
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 85 deletions.
162 changes: 80 additions & 82 deletions python/cuml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,91 +15,89 @@
#

from cuml.internals.base import Base, UniversalBase
from cuml.internals.available_devices import is_cuda_available

# GPU only packages

if (is_cuda_available()):
import cuml.common.cuda as cuda
from cuml.common.handle import Handle

from cuml.cluster.dbscan import DBSCAN
from cuml.cluster.kmeans import KMeans
from cuml.cluster.agglomerative import AgglomerativeClustering
from cuml.cluster.hdbscan import HDBSCAN

from cuml.datasets.arima import make_arima
from cuml.datasets.blobs import make_blobs
from cuml.datasets.regression import make_regression
from cuml.datasets.classification import make_classification

from cuml.decomposition.pca import PCA
from cuml.decomposition.tsvd import TruncatedSVD
from cuml.decomposition.incremental_pca import IncrementalPCA

from cuml.fil.fil import ForestInference

from cuml.ensemble.randomforestclassifier import RandomForestClassifier
from cuml.ensemble.randomforestregressor import RandomForestRegressor

from cuml.explainer.kernel_shap import KernelExplainer
from cuml.explainer.permutation_shap import PermutationExplainer
from cuml.explainer.tree_shap import TreeExplainer

import cuml.feature_extraction
from cuml.fil import fil

from cuml.internals.global_settings import (
GlobalSettings, _global_settings_data)

from cuml.kernel_ridge.kernel_ridge import KernelRidge

from cuml.linear_model.elastic_net import ElasticNet
from cuml.linear_model.lasso import Lasso
from cuml.linear_model.logistic_regression import LogisticRegression
from cuml.linear_model.mbsgd_classifier import MBSGDClassifier
from cuml.linear_model.mbsgd_regressor import MBSGDRegressor
from cuml.linear_model.ridge import Ridge

from cuml.manifold.t_sne import TSNE
from cuml.manifold.umap import UMAP
from cuml.metrics.accuracy import accuracy_score
from cuml.metrics.cluster.adjusted_rand_index import adjusted_rand_score
from cuml.metrics.regression import r2_score
from cuml.model_selection import train_test_split

from cuml.naive_bayes.naive_bayes import MultinomialNB

from cuml.neighbors.nearest_neighbors import NearestNeighbors
from cuml.neighbors.kernel_density import KernelDensity
from cuml.neighbors.kneighbors_classifier import KNeighborsClassifier
from cuml.neighbors.kneighbors_regressor import KNeighborsRegressor

from cuml.preprocessing.LabelEncoder import LabelEncoder

from cuml.random_projection.random_projection import \
GaussianRandomProjection
from cuml.random_projection.random_projection import SparseRandomProjection
from cuml.random_projection.random_projection import \
johnson_lindenstrauss_min_dim

from cuml.solvers.cd import CD
from cuml.solvers.sgd import SGD
from cuml.solvers.qn import QN
from cuml.svm import SVC
from cuml.svm import SVR
from cuml.svm import LinearSVC
from cuml.svm import LinearSVR

from cuml.tsa import stationarity
from cuml.tsa.arima import ARIMA
from cuml.tsa.auto_arima import AutoARIMA
from cuml.tsa.holtwinters import ExponentialSmoothing

from cuml.common.pointer_utils import device_of_gpu_matrix
from cuml.internals.memory_utils import (
set_global_output_type, using_output_type
)
import cuml.common.cuda as cuda
from cuml.common.handle import Handle

from cuml.cluster.dbscan import DBSCAN
from cuml.cluster.kmeans import KMeans
from cuml.cluster.agglomerative import AgglomerativeClustering
from cuml.cluster.hdbscan import HDBSCAN

from cuml.datasets.arima import make_arima
from cuml.datasets.blobs import make_blobs
from cuml.datasets.regression import make_regression
from cuml.datasets.classification import make_classification

from cuml.decomposition.pca import PCA
from cuml.decomposition.tsvd import TruncatedSVD
from cuml.decomposition.incremental_pca import IncrementalPCA

from cuml.fil.fil import ForestInference

from cuml.ensemble.randomforestclassifier import RandomForestClassifier
from cuml.ensemble.randomforestregressor import RandomForestRegressor

from cuml.explainer.kernel_shap import KernelExplainer
from cuml.explainer.permutation_shap import PermutationExplainer
from cuml.explainer.tree_shap import TreeExplainer

import cuml.feature_extraction
from cuml.fil import fil

from cuml.internals.global_settings import (
GlobalSettings, _global_settings_data)

from cuml.kernel_ridge.kernel_ridge import KernelRidge

from cuml.linear_model.elastic_net import ElasticNet
from cuml.linear_model.lasso import Lasso
from cuml.linear_model.logistic_regression import LogisticRegression
from cuml.linear_model.mbsgd_classifier import MBSGDClassifier
from cuml.linear_model.mbsgd_regressor import MBSGDRegressor
from cuml.linear_model.ridge import Ridge

from cuml.manifold.t_sne import TSNE
from cuml.manifold.umap import UMAP
from cuml.metrics.accuracy import accuracy_score
from cuml.metrics.cluster.adjusted_rand_index import adjusted_rand_score
from cuml.metrics.regression import r2_score
from cuml.model_selection import train_test_split

from cuml.naive_bayes.naive_bayes import MultinomialNB

from cuml.neighbors.nearest_neighbors import NearestNeighbors
from cuml.neighbors.kernel_density import KernelDensity
from cuml.neighbors.kneighbors_classifier import KNeighborsClassifier
from cuml.neighbors.kneighbors_regressor import KNeighborsRegressor

from cuml.preprocessing.LabelEncoder import LabelEncoder

from cuml.random_projection.random_projection import \
GaussianRandomProjection
from cuml.random_projection.random_projection import SparseRandomProjection
from cuml.random_projection.random_projection import \
johnson_lindenstrauss_min_dim

from cuml.solvers.cd import CD
from cuml.solvers.sgd import SGD
from cuml.solvers.qn import QN
from cuml.svm import SVC
from cuml.svm import SVR
from cuml.svm import LinearSVC
from cuml.svm import LinearSVR

from cuml.tsa import stationarity
from cuml.tsa.arima import ARIMA
from cuml.tsa.auto_arima import AutoARIMA
from cuml.tsa.holtwinters import ExponentialSmoothing

from cuml.common.pointer_utils import device_of_gpu_matrix
from cuml.internals.memory_utils import (
set_global_output_type, using_output_type
)

# Universal packages

Expand Down
8 changes: 6 additions & 2 deletions python/cuml/internals/array_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class SparseCumlArray():
domain="cuml_python")
def __init__(self, data=None,
convert_to_dtype=False,
convert_to_mem_type=GlobalSettings().memory_type,
convert_index=GlobalSettings().xpy.int32,
convert_to_mem_type=None,
convert_index=None,
convert_format=True):
is_sparse = False
try:
Expand Down Expand Up @@ -132,12 +132,16 @@ def __init__(self, data=None,

if convert_to_mem_type:
convert_to_mem_type = MemoryType.from_str(convert_to_mem_type)
else:
convert_to_mem_type = GlobalSettings().memory_type

if convert_to_mem_type is MemoryType.mirror or not convert_to_mem_type:
convert_to_mem_type = from_mem_type

self._mem_type = convert_to_mem_type

if convert_index is None:
convert_index = GlobalSettings().xpy.int32
if not convert_index:
convert_index = data.indptr.dtype

Expand Down
1 change: 0 additions & 1 deletion python/cuml/neighbors/kneighbors_classifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ from libc.stdlib cimport calloc, malloc, free

from cuml.internals.safe_imports import gpu_only_import_from
cuda = gpu_only_import_from('numba', 'cuda')
import rmm

cimport cuml.common.cuda

Expand Down

0 comments on commit 277f2da

Please sign in to comment.