Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved CPU/GPU interoperability #5001

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
292 commits
Select commit Hold shift + click to select a range
e14fee8
Update array import path
wphicks Oct 4, 2022
e3d0737
Update input_utils import path
wphicks Oct 4, 2022
0072c2a
Move type_utils to internals
wphicks Oct 4, 2022
c22fcec
Remove unused imports
wphicks Oct 4, 2022
2214d4e
Move import_utils to internals
wphicks Oct 4, 2022
5c03489
Fix tests
viclafargue Oct 5, 2022
09db802
Using raft::KeyValuePair instead of cub::KeyValuePair
viclafargue Oct 5, 2022
57d27e9
Begin adjusting CumlArray for host inputs
wphicks Oct 5, 2022
0dbe8a4
Revert "Using raft::KeyValuePair instead of cub::KeyValuePair"
cjnolet Oct 6, 2022
495c676
Fix _check_internal_model
viclafargue Oct 7, 2022
73bb16c
Generic testing
viclafargue Oct 7, 2022
a232ff6
Adding UMAP
viclafargue Oct 7, 2022
edc4d84
Adding LogisticRegression
viclafargue Oct 7, 2022
3c3328e
Merge branch 'branch-22.10' into cpu-gpu-interop-models
viclafargue Oct 10, 2022
9d83c4b
Add error checking for array construction
wphicks Oct 10, 2022
7b952d6
Finish initial to_output implementation for CumlArray
wphicks Oct 11, 2022
fb08a3f
Update CumlArray construction methods for mem_type
wphicks Oct 11, 2022
e5f88cf
Adding LogisticRegression 2
viclafargue Oct 11, 2022
1ac78ed
Use consistent methods for detecting mem accessibility
wphicks Oct 11, 2022
df4bc5b
Handle backward-compatibility for CumlArray to_output
wphicks Oct 11, 2022
4a79bea
Move logger to internals
wphicks Oct 11, 2022
20b55d4
Guard imports in input_utils
wphicks Oct 11, 2022
ac67b32
Begin refactoring input_utils with optional dependencies
wphicks Oct 12, 2022
cda3cc8
Revert "Move logger to internals"
wphicks Oct 12, 2022
5da57ee
Move logger.pyx to internals
wphicks Oct 12, 2022
97ba765
Update logger import path
wphicks Oct 12, 2022
3c3794a
Update remaining conversions in input_utils
wphicks Oct 12, 2022
259b1be
Reimplement is_array_like
wphicks Oct 12, 2022
668dfb3
Avoid short-circuiting check for is_array_like
wphicks Oct 12, 2022
114be42
Implement determine_array_memtype
wphicks Oct 12, 2022
0b8f673
Update convert_dtype utility
wphicks Oct 12, 2022
2ab5ed2
Update base.pyx imports
wphicks Oct 12, 2022
2f9b2ac
Improved estimator initialization
viclafargue Oct 13, 2022
26f1fd1
Checker for similarity of hyperparameters default values
viclafargue Oct 13, 2022
a5a8834
Update Base classes for safe import
wphicks Oct 13, 2022
5bd3def
Separate DeviceType and MemType definitions
wphicks Oct 13, 2022
9ea097b
Make array_sparse CPU-safe
wphicks Oct 13, 2022
c5d78c7
Break circular import
wphicks Oct 13, 2022
d2a857a
Hyperparameters check as a test
viclafargue Oct 14, 2022
c594329
Move mixins.py to internals
wphicks Oct 14, 2022
ef54431
Break circular imports in mixins.py
wphicks Oct 14, 2022
0c37722
info instead of warn
viclafargue Oct 14, 2022
d1b9a3d
Merge branch 'branch-22.12' into fea-xpu_infra
wphicks Oct 14, 2022
a8b25be
Break circular import for base return type introspection
wphicks Oct 14, 2022
9452992
Merge branch 'branch-22.12' into fea-xpu_infra
dantegd Oct 17, 2022
b6d5db9
FIX FIrst round of import fixes
dantegd Oct 17, 2022
d9df07b
FIX array desc import update
dantegd Oct 17, 2022
a663f85
FIX import logger fix
dantegd Oct 18, 2022
cda6ecf
FIX another round of import updates
dantegd Oct 18, 2022
d91d924
Add Lasso, ElasticNet, Ridge
viclafargue Oct 19, 2022
49a7c3c
Add PCA
viclafargue Oct 20, 2022
4594f3a
Add TSVD
viclafargue Oct 20, 2022
8cdecb7
Merge branch-22.10
viclafargue Oct 21, 2022
94e9806
FEA CMake changes checkpoint 1
dantegd Oct 22, 2022
b9293a5
Merge 22.12 into dev-xpu_infra
dantegd Oct 22, 2022
90a03e5
FEA cmake add module function in linear_model
dantegd Oct 24, 2022
6b37bd9
FEA additional changes for universal linear reg
dantegd Oct 24, 2022
523aaed
Add NearestNeighbors + improve CumlArrayDescriptor
viclafargue Oct 26, 2022
b4584fe
Use of the same attributes
viclafargue Oct 26, 2022
2552b7c
Dev xpu infra (#7)
dantegd Oct 26, 2022
cac70e3
Correct mem_type handling in tests
wphicks Oct 27, 2022
570e03b
Correct order detection
wphicks Oct 27, 2022
53a601f
Correct shape checking
wphicks Oct 27, 2022
4ba121a
Correct handling of bytes input
wphicks Oct 27, 2022
1650a82
Improve coverage for LogisticRegression
viclafargue Oct 28, 2022
9bcc59f
Improve coverage for LinearRegression
viclafargue Oct 28, 2022
4ddb837
Improve coverage for lasso, elasticnet, ridge
viclafargue Oct 28, 2022
2a4427e
Update headers
viclafargue Oct 28, 2022
5802cfd
flake8 style
viclafargue Oct 28, 2022
f038485
Redirection to CPU
viclafargue Oct 28, 2022
1b21375
Update to_output tests
wphicks Oct 28, 2022
02b7e45
Fix output tests
wphicks Oct 28, 2022
e6c56a8
Update to pass all existing non-serialization array tests
wphicks Oct 28, 2022
7a784ac
Fix numerous minor issues
viclafargue Oct 31, 2022
ac71a2c
Trace memory type through context managers
wphicks Oct 31, 2022
e6a0fed
Move to inherited UniversalBase
wphicks Oct 31, 2022
cbb1fd7
Correct from_input usage in array_sparse
wphicks Oct 31, 2022
a85cdb2
Add memory type for sparse conversions
wphicks Nov 1, 2022
22b5435
Allow comparisons against other objects in array equality
wphicks Nov 1, 2022
508950f
Update array validation for backward compatibility
wphicks Nov 1, 2022
b19b31e
Correctly store cuml estimator attributes on device
wphicks Nov 1, 2022
07fa15a
Preserve order during array conversion
wphicks Nov 1, 2022
5bc59e0
Methods coverage for UMAP, PCA, tSVD and NN
viclafargue Nov 2, 2022
7f25eb0
two score functions in mixins
viclafargue Nov 2, 2022
63fa608
Fix header
viclafargue Nov 2, 2022
f89a416
Improve CumlArray conversion runtime
wphicks Nov 2, 2022
c288b1f
Cache dtype
wphicks Nov 2, 2022
ccad4cc
Revert debug change
wphicks Nov 2, 2022
2ca6fd2
Avoid extra copy for Series output
wphicks Nov 2, 2022
298da8c
Add optimized __iter__ method to array
wphicks Nov 2, 2022
c7d446f
Update docs for Base
wphicks Nov 2, 2022
1d138da
Update array docs
wphicks Nov 2, 2022
92018de
Update docs path
wphicks Nov 3, 2022
e7b9078
Merge branch 'branch-22.12' into fea-xpu_infra
wphicks Nov 3, 2022
9743a7a
Merge branch 'branch-22.12' into cpu-gpu-interop-models
viclafargue Nov 3, 2022
7ae2f92
Improve exception messages in safe imports
wphicks Nov 3, 2022
dc19310
better testing of LogisticRegression log_proba
viclafargue Nov 4, 2022
c9acaea
deepcopy for Ridge input
viclafargue Nov 4, 2022
36f5a12
skip some of the tests
viclafargue Nov 4, 2022
2e9d04d
avoid double kwargs processing when using a child class
viclafargue Nov 4, 2022
4a32589
restore tests
viclafargue Nov 4, 2022
c0d48ef
Fix remaining issues
viclafargue Nov 8, 2022
3a37b5a
minor improvements
viclafargue Nov 8, 2022
21002e6
Adress reviews
viclafargue Nov 10, 2022
ae11e35
Merge branch 'cpu-gpu-interop-models' into fea-xpu_infra
wphicks Nov 10, 2022
92e44a6
Fix base doc test
viclafargue Nov 11, 2022
06e0f8b
fix typo
viclafargue Nov 14, 2022
aea1143
Merge branch 'branch-22.12' into cpu-gpu-interop-models
wphicks Nov 15, 2022
ab66ac4
Merge branch 'cpu-gpu-interop-models' into fea-xpu_infra
wphicks Nov 15, 2022
615cdfd
Correct merge failures
wphicks Nov 15, 2022
0632ebc
address review
viclafargue Nov 16, 2022
888e975
Correct stride handling for serialization
wphicks Nov 16, 2022
2af7cf0
Update CumlArray __init__ for better backward compatibility
wphicks Nov 16, 2022
e358ac9
Default to global memory for pointer arrays
wphicks Nov 16, 2022
b35d2db
Merge branch 'cpu-gpu-interop-models' into fea-xpu_infra
wphicks Nov 16, 2022
f96e3cb
Merge branch 'branch-22.12' into dev-4918
wphicks Nov 16, 2022
50bb04e
Merge branch 'dev-4918' into fea-xpu_infra
wphicks Nov 16, 2022
43c3ccb
Correct UniversalBase merge
wphicks Nov 16, 2022
ad4dc5e
Remove autoarray decorator
wphicks Nov 16, 2022
a054328
Correct base docstring
wphicks Nov 16, 2022
462a68c
Update output_type list
wphicks Nov 16, 2022
1bc022c
Add missing input types to docs
wphicks Nov 16, 2022
8477085
Remove debug statements
wphicks Nov 16, 2022
53476cf
Correctly identify sparse arrays
wphicks Nov 17, 2022
45889a1
Correct memory type usage during transfer
wphicks Nov 17, 2022
c83dc99
Correct inheritance order in ridge regression
wphicks Nov 17, 2022
ed23281
Dispatch kneighbors_graph to correct device
wphicks Nov 17, 2022
42b53fa
Correct fail_oon_order behavior
wphicks Nov 17, 2022
fbbfbc4
Correct order handling for CumlArray conversions
wphicks Nov 17, 2022
960804f
Switch all input_to_host_array to using cuml_array
wphicks Nov 17, 2022
c47cb19
Do not short-circuit checks in CumlArray conversion
wphicks Nov 17, 2022
2044535
Handle convert_to_mem_type=False correctly
wphicks Nov 17, 2022
acb12eb
Reflect legacy behavior for convert_to_dtype
wphicks Nov 17, 2022
9832255
Correct contiguity check
wphicks Nov 17, 2022
024a771
Correct contiguity checks for single dim array
wphicks Nov 17, 2022
1738aeb
Keep index on correct device
wphicks Nov 18, 2022
cb35dac
Preserve index through all conversions
wphicks Nov 18, 2022
e615738
Revert to legacy output type string mapping
wphicks Nov 18, 2022
9a552c5
Remove debug prints
wphicks Nov 18, 2022
c242eb1
Correct handling of dataframes in testing
wphicks Nov 18, 2022
4870d2b
Add missing decorator for kneighbors
wphicks Nov 18, 2022
081a90a
Correct handling of pandas Series in input_utils
wphicks Nov 18, 2022
f254fa2
Suppress data_too_large healthcheck in Hypothesis-based test
wphicks Nov 18, 2022
c7d3143
Convert to current memory type on deserialization
wphicks Nov 18, 2022
221bf2e
Correct test for array memtype after deserialization
wphicks Nov 18, 2022
8b98fd8
Correct mem_type conversion in deserialization
wphicks Nov 21, 2022
8e9abb5
Clean up global_settings import
wphicks Nov 22, 2022
28ea4ef
Revert "Clean up global_settings import"
wphicks Nov 22, 2022
04262ec
Correct global_settings usage
wphicks Nov 22, 2022
8af3b33
Update copyright headers
wphicks Nov 22, 2022
1a80a87
Update style
wphicks Nov 23, 2022
71d1651
Restore missing import
wphicks Nov 23, 2022
666d169
Remove experimental internals directory
wphicks Nov 23, 2022
60b9eea
Correct docstring mismatch
wphicks Nov 23, 2022
4d0d8ae
Fix tests for expanded output type list
wphicks Nov 23, 2022
c0dc07d
Add missing error string construction
wphicks Nov 23, 2022
a7e512e
Incorporate review feedback
wphicks Nov 24, 2022
18395f1
Rename placeholder context
wphicks Nov 24, 2022
fb97dd7
Make safe import kwargs keyword-only
wphicks Nov 24, 2022
319fa6e
Update based on review feedback
wphicks Nov 24, 2022
7a80dc3
Convert output_type context manager to class
wphicks Nov 28, 2022
7998a60
Correct typo in context manager
wphicks Nov 28, 2022
ea9b9f0
Add memory type to estimator guide
wphicks Nov 28, 2022
4968387
Correct mem type reporting from base
wphicks Nov 28, 2022
6a0ee09
Correct is_array implementation for unavailable types
wphicks Nov 28, 2022
b42babd
Use is for type identity check
wphicks Nov 28, 2022
77fa78c
Provide better error handling for bad memory types
wphicks Nov 28, 2022
1bde3c0
Update style
wphicks Nov 29, 2022
d25d385
WIP: Implement hypothesis strategies and tests for arrays
csadorf Nov 21, 2022
454a073
Remove explicit testing of numba arrays.
csadorf Nov 22, 2022
b5de8d4
Continue implementation.
csadorf Nov 22, 2022
716e02f
Make create_cuml_array_input public function.
csadorf Nov 22, 2022
5ff60ad
Hypothesize test_get_set_item.
csadorf Nov 22, 2022
7665b23
Raise ValueError for invalid input to cuml_array_shapes.
csadorf Nov 22, 2022
f98739b
The cuml_array_shapes() strategy also returns integers.
csadorf Nov 22, 2022
5b990eb
Only run standard number of examples.
csadorf Nov 22, 2022
3f35726
Hypothesize test_create_empty.
csadorf Nov 22, 2022
03d0f0e
Reenable DeviceBuffer check since #4332 is resolved.
csadorf Nov 22, 2022
0b81196
Remove obsolete py<38 compatibility work-around.
csadorf Nov 22, 2022
b8774ac
Hypothesize test_create_* tests.
csadorf Nov 22, 2022
b214e77
Improve shape normalization and inspection.
csadorf Nov 22, 2022
b369b90
Hypothesize test_output test.
csadorf Nov 22, 2022
a698a23
Hypothesize test_output_dtype test.
csadorf Nov 22, 2022
573916a
Hypothesize test_cuda_array_interface test.
csadorf Nov 22, 2022
ccf730d
Hypothesize test_serialize test.
csadorf Nov 22, 2022
751391a
Hypothesize test_cumlary_binops and test_deepcopy tests.
csadorf Nov 22, 2022
0f95380
Improve cuml_arrays strategy (currently not used).
csadorf Nov 22, 2022
0782b4e
Cleanup test_array test module.
csadorf Nov 22, 2022
4ba3fc3
Use less rigorous mulit-dimension check for init_array.
csadorf Nov 22, 2022
a42da0a
Move test of array_inputs strategies into test_strategies module.
csadorf Nov 22, 2022
c9e02fa
Implement test_get_set_item with cuml_array_inputs.
csadorf Nov 24, 2022
cec283e
Implement test_output with cuml_array_inputs.
csadorf Nov 24, 2022
85ec179
Fix multidim check for test_output_dtype.
csadorf Nov 24, 2022
7051f93
Implement test_cuda_array_interface test with cuml_array_inputs.
csadorf Nov 24, 2022
9b3b7b4
Implement test_serialize with cuml_array_inputs.
csadorf Nov 24, 2022
d7e078b
Implement test_pickle with cuml_array_inputs.
csadorf Nov 24, 2022
49e2bb9
Implement test_deepcopy with cuml_array_inputs.
csadorf Nov 24, 2022
1f14486
Document new strategies.
csadorf Nov 28, 2022
0944b43
Remove obsolete None (default) value from valid cuml array input types.
csadorf Nov 29, 2022
732c7f7
Adjust cuml_array_shapes() max_side default value.
csadorf Nov 29, 2022
87c28c0
Apply isort and black formatting.
csadorf Nov 29, 2022
ec5f957
Remove _CUML_ARRAY_OUTPUT_DTYPES constant.
csadorf Nov 29, 2022
5c1d5b3
Remove todo comment (captured in discussion).
csadorf Nov 29, 2022
0d75b19
Merge branch 'branch-22.12' into fea-xpu_infra
wphicks Nov 29, 2022
6de2622
Correct branch 22.12 merge
wphicks Nov 30, 2022
816cf48
Get multi-output linear regression working with CPU infra
wphicks Nov 30, 2022
dec3907
Merge branch 'branch-22.12' into fea-xpu_infra
wphicks Nov 30, 2022
a399e85
The cuml_array_inputs strategy generates more arbitrary arrays.
csadorf Nov 30, 2022
f16d83a
Rename methods for current dispatch
wphicks Nov 30, 2022
bf7eb45
Correct usage of strides to order
wphicks Nov 30, 2022
4339d4b
Update style
wphicks Nov 30, 2022
143648f
Merge branch 'branch-23.02' into fea-xpu_infra
wphicks Dec 1, 2022
fe970bb
Remove outdated reference to experimentalBase
wphicks Dec 1, 2022
92c4e80
Indicate reason for conditional import of functools.cache
wphicks Dec 1, 2022
85ba5f6
More cleanly retrieve estimator handle.
wphicks Dec 1, 2022
e93d399
More cleanly retrieve estimator handle
wphicks Dec 1, 2022
7b5908e
Combine branches
wphicks Dec 1, 2022
3a65f27
Merge branch 'fea-hypothesis-stratgies-and-tests-for-arrays' into dev…
wphicks Dec 2, 2022
2345e62
Update array tests for memory type
wphicks Dec 2, 2022
4add7d9
Correctly get dtype from series in CumlArray conversion
wphicks Dec 2, 2022
c8cf98c
Correctly handle cupy strides bug everywhere
wphicks Dec 5, 2022
07a7c5d
Remove deadline on test
wphicks Dec 5, 2022
59d7928
Correctly convert to array in multi-output linear model
wphicks Dec 5, 2022
ea83dd7
Fix multi-output regression implementation after merge
wphicks Dec 5, 2022
7611b64
Correct issues uncovered by Hypothesis tests
wphicks Dec 5, 2022
f3837bb
Merge remote-tracking branch 'origin/fea-xpu_infra' into fea-xpu_infra
wphicks Dec 5, 2022
dd026d7
Incorporate review feedback
wphicks Dec 7, 2022
ded2318
Simplify null case statement
wphicks Dec 7, 2022
4cd94df
Update contiguity and order checks
wphicks Dec 7, 2022
939ed0c
Correct conditions under which order is computed
wphicks Dec 8, 2022
047e45a
Correctly handle order detection when beginning and end strides are t…
wphicks Dec 8, 2022
d415ee6
Merge branch 'dev-xpu_array_tests' into fea-xpu_infra
wphicks Dec 8, 2022
41f39b7
Merge branch 'branch-23.02' into fea-xpu_infra
wphicks Dec 8, 2022
d034efd
Correct handling of train/test split after merge
wphicks Dec 9, 2022
f448ae4
Provide more comprehensive handling of stride serialization
wphicks Dec 9, 2022
fc5ca24
Update style
wphicks Dec 9, 2022
8172a75
Import logger into common
wphicks Dec 9, 2022
85026ab
Revert change to CHANGELOG
wphicks Dec 9, 2022
43f8545
Merge branch 'branch-23.02' into fea-xpu_infra
wphicks Dec 9, 2022
cda3fbd
Correct __le__ method for CumlArray
wphicks Dec 12, 2022
c7988c0
Remove old using_memory_type and set_global_memory_type
wphicks Dec 12, 2022
71e3650
Remove old device type setters
wphicks Dec 12, 2022
3526df1
Add tests for stride computation and order detection
wphicks Dec 12, 2022
9916639
Merge branch 'branch-23.02' into fea-xpu_infra
wphicks Dec 12, 2022
ac3d542
Correct mistake in merge
wphicks Dec 13, 2022
4f61962
Improved CPU/GPU interoperability
viclafargue Dec 14, 2022
fa8b051
Modularization of the dispatch function
viclafargue Dec 14, 2022
78432b3
Merge branch-23.02
viclafargue Dec 15, 2022
a43d502
fixes
viclafargue Dec 16, 2022
15d3211
FIX Remove CumlArray from api.rst
dantegd Dec 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,6 @@ Dataset Generation (Dask-based Multi-GPU)
.. automodule:: cuml.dask.datasets.regression
:members:

Array Wrappers (Internal API)
-----------------------------

.. autoclass:: cuml.common.CumlArray
:members:

Metrics (regression, classification, and distance)
--------------------------------------------------
Expand Down
15 changes: 10 additions & 5 deletions python/cuml/decomposition/pca.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ from cuml.common.exceptions import NotFittedError
from cuml.internals.mixins import FMajorInputTagMixin
from cuml.internals.mixins import SparseInputTagMixin
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop


cdef extern from "cuml/decomposition/pca.hpp" namespace "ML":
Expand Down Expand Up @@ -406,7 +407,8 @@ class PCA(UniversalBase,
return self

@generate_docstring(X='dense_sparse')
def _fit(self, X, y=None) -> "PCA":
@enable_device_interop
def fit(self, X, y=None) -> "PCA":
"""
Fit the model with X. y is currently ignored.

Expand Down Expand Up @@ -496,7 +498,8 @@ class PCA(UniversalBase,
'description': 'Transformed values',
'shape': '(n_samples, n_components)'})
@cuml.internals.api_base_return_array_skipall
def _fit_transform(self, X, y=None) -> CumlArray:
@enable_device_interop
def fit_transform(self, X, y=None) -> CumlArray:
"""
Fit the model with X and apply the dimensionality reduction on X.

Expand Down Expand Up @@ -541,8 +544,9 @@ class PCA(UniversalBase,
'type': 'dense_sparse',
'description': 'Transformed values',
'shape': '(n_samples, n_features)'})
def _inverse_transform(self, X, convert_dtype=False,
return_sparse=False, sparse_tol=1e-10) -> CumlArray:
@enable_device_interop
def inverse_transform(self, X, convert_dtype=False,
return_sparse=False, sparse_tol=1e-10) -> CumlArray:
"""
Transform data back to its original space.

Expand Down Expand Up @@ -642,7 +646,8 @@ class PCA(UniversalBase,
'type': 'dense_sparse',
'description': 'Transformed values',
'shape': '(n_samples, n_components)'})
def _transform(self, X, convert_dtype=False) -> CumlArray:
@enable_device_interop
def transform(self, X, convert_dtype=False) -> CumlArray:
"""
Apply dimensionality reduction to X.

Expand Down
13 changes: 9 additions & 4 deletions python/cuml/decomposition/tsvd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common.doc_utils import generate_docstring
from cuml.internals.mixins import FMajorInputTagMixin
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop

from cython.operator cimport dereference as deref

Expand Down Expand Up @@ -299,7 +300,8 @@ class TruncatedSVD(UniversalBase,
dtype=self.dtype)

@generate_docstring()
def _fit(self, X, y=None) -> "TruncatedSVD":
@enable_device_interop
def fit(self, X, y=None) -> "TruncatedSVD":
"""
Fit LSI model on training cudf DataFrame X. y is currently ignored.

Expand All @@ -313,7 +315,8 @@ class TruncatedSVD(UniversalBase,
'type': 'dense',
'description': 'Reduced version of X',
'shape': '(n_samples, n_components)'})
def _fit_transform(self, X, y=None) -> CumlArray:
@enable_device_interop
def fit_transform(self, X, y=None) -> CumlArray:
"""
Fit LSI model to X and perform dimensionality reduction on X.
y is currently ignored.
Expand Down Expand Up @@ -377,7 +380,8 @@ class TruncatedSVD(UniversalBase,
'type': 'dense',
'description': 'X in original space',
'shape': '(n_samples, n_features)'})
def _inverse_transform(self, X, convert_dtype=False) -> CumlArray:
@enable_device_interop
def inverse_transform(self, X, convert_dtype=False) -> CumlArray:
"""
Transform X back to its original space.
Returns X_original whose transform would be X.
Expand Down Expand Up @@ -426,7 +430,8 @@ class TruncatedSVD(UniversalBase,
'type': 'dense',
'description': 'Reduced version of X',
'shape': '(n_samples, n_components)'})
def _transform(self, X, convert_dtype=False) -> CumlArray:
@enable_device_interop
def transform(self, X, convert_dtype=False) -> CumlArray:
"""
Perform dimensionality reduction on X.

Expand Down
40 changes: 15 additions & 25 deletions python/cuml/internals/api_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import typing
from functools import wraps
import warnings
from importlib import import_module

import cuml.internals.array
import cuml.internals.array_sparse
Expand Down Expand Up @@ -779,38 +778,17 @@ def inner_f(*args, **kwargs):

def device_interop_preparation(init_func):
"""
This function serves as a decorator to cuML estimators that implement
the CPU/GPU interoperability feature. It imports the joint CPU estimator
and processes the hyperparameters.
This function serves as a decorator for cuML estimators that implement
the CPU/GPU interoperability feature. It processes the estimator's
hyperparameters by saving them and filtering them for GPU execution.
"""

@functools.wraps(init_func)
def processor(self, *args, **kwargs):
# if child class (parent class was already decorated), skip
if hasattr(self, '_cpu_model_class'):
return init_func(self, *args, **kwargs)

if hasattr(self, '_cpu_estimator_import_path'):
# if import path differs from the one of sklearn
# look for _cpu_estimator_import_path
estimator_path = self._cpu_estimator_import_path.split('.')
model_path = '.'.join(estimator_path[:-1])
model_name = estimator_path[-1]
else:
# import from similar path to the current estimator
# class
model_path = 'sklearn' + self.__class__.__module__[4:]
model_name = self.__class__.__name__
self._cpu_model_class = getattr(import_module(model_path), model_name)

# Save all kwargs
self._full_kwargs = kwargs
# Generate list of available cuML hyperparameters
gpu_hyperparams = list(inspect.signature(init_func).parameters.keys())
# Save list of available CPU estimator hyperparameters
self._cpu_hyperparams = list(
inspect.signature(self._cpu_model_class.__init__).parameters.keys()
)

# Filter provided parameters for cuML estimator initialization
filtered_kwargs = {}
Expand All @@ -824,3 +802,15 @@ def processor(self, *args, **kwargs):

return init_func(self, *args, **filtered_kwargs)
return processor


def enable_device_interop(gpu_func):
@functools.wraps(gpu_func)
def dispatch(self, *args, **kwargs):
# check that the estimator implements CPU/GPU interoperability
if hasattr(self, 'dispatch_func'):
func_name = gpu_func.__name__
return self.dispatch_func(func_name, gpu_func, *args, **kwargs)
else:
return gpu_func(self, *args, **kwargs)
return dispatch
Loading