Skip to content

Commit

Permalink
ENH deepcopy parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
dantegd committed Dec 20, 2024
1 parent f925864 commit 228097c
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@

# distutils: language = c++

import copy
import os
import inspect
import numbers
import pickle
from importlib import import_module
from cuml.internals.device_support import GPU_ENABLED
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import_from,
null_decorator
null_decorator,
safe_import
)
joblib = safe_import(module="joblib")
np = cpu_only_import('numpy')
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)

Expand Down Expand Up @@ -852,7 +856,7 @@ class UniversalBase(Base):

raise ex

def as_sklearn(self):
def as_sklearn(self, deepcopy=False):
"""
Convert the current GPU-accelerated estimator into a scikit-learn estimator.
Expand All @@ -862,6 +866,15 @@ class UniversalBase(Base):
compatible scikit-learn estimator, allowing you to use it in standard
scikit-learn pipelines and workflows.
Parameters
----------
deepcopy : boolean (default=False)
Whether to return a deepcopy of the internal scikit-learn estimator of
the cuML models. cuML models internally have CPU based estimators that
could be updated. If you intend to use both the cuML and the scikit-learn
estimators after using the method in parallel, it is recommended to set
this to True to avoid one overwriting data of the other.
Returns
-------
sklearn.base.BaseEstimator
Expand All @@ -872,7 +885,10 @@ class UniversalBase(Base):
self.import_cpu_model()
self.build_cpu_model()
self.gpu_to_cpu()
return self._cpu_model
if deepcopy:
return copy.deepcopy(self._cpu_model)
else:
return self._cpu_model

@classmethod
def from_sklearn(cls, model):
Expand Down

0 comments on commit 228097c

Please sign in to comment.