Skip to content

Commit

Permalink
[REVIEW] API update to match RAFT PR #120 (#3386)
Browse files Browse the repository at this point in the history
* API update to match RAFT PR #120

* Update raft git tag

* Update copyrights

* Copyright update

Co-authored-by: Devin Robison <[email protected]>
  • Loading branch information
drobison00 and drobison00 authored Jan 24, 2021
1 parent 4ce2380 commit 6b5e7ff
Show file tree
Hide file tree
Showing 14 changed files with 36 additions and 35 deletions.
4 changes: 2 additions & 2 deletions cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG 9161d7a238aca859453d8517bd7ad92cbd902f6a
GIT_TAG 9dbf2c8a9134ce8135f7fe947ec523d874fcab6a
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
9 changes: 5 additions & 4 deletions python/cuml/dask/cluster/kmeans.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,7 +24,7 @@
from cuml.dask.common.input_utils import DistributedDataHandler

from cuml.raft.dask.common.comms import Comms
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state

from cuml.dask.common.utils import wait_and_raise_from_futures

Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, client=None, verbose=False, **kwargs):
@mnmg_import
def _func_fit(sessionId, objs, datatype, **kwargs):
from cuml.cluster.kmeans_mg import KMeansMG as cumlKMeans
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]

inp_data = concatenate(objs)

Expand All @@ -132,7 +132,8 @@ def fit(self, X):
data = DistributedDataHandler.create(X, client=self.client)
self.datatype = data.datatype

comms = Comms(comms_p2p=False)
# This needs to happen on the scheduler
comms = Comms(comms_p2p=False, client=self.client)
comms.init(workers=data.workers)

kmeans_fit = [self.client.submit(KMeans._func_fit,
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/decomposition/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@
#

from cuml.dask.common import raise_exception_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms

from cuml.dask.common.input_utils import to_output
Expand Down Expand Up @@ -118,5 +118,5 @@ def _fit(self, X, _transform=False):

@staticmethod
def _create_model(sessionId, model_func, datatype, **kwargs):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return model_func(handle, datatype, **kwargs)
2 changes: 1 addition & 1 deletion python/cuml/dask/ensemble/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/dask/ensemble/randomforestclassifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/dask/ensemble/randomforestregressor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class LinearRegression(BaseEstimator,
Expand Down Expand Up @@ -113,6 +113,6 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.linear_regression_mg import LinearRegressionMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return LinearRegressionMG(handle=handle, output_type=datatype,
**kwargs)
6 changes: 3 additions & 3 deletions python/cuml/dask/linear_model/ridge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class Ridge(BaseEstimator,
Expand Down Expand Up @@ -124,5 +124,5 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.ridge_mg import RidgeMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return RidgeMG(handle=handle, output_type=datatype, **kwargs)
6 changes: 3 additions & 3 deletions python/cuml/dask/neighbors/kneighbors_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
from dask.dataframe import Series as DaskSeries
import dask.array as da
Expand Down Expand Up @@ -124,7 +124,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/neighbors/kneighbors_regressor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
import dask.array as da
from uuid import uuid1
Expand Down Expand Up @@ -101,7 +101,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/neighbors/nearest_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@
from cuml.dask.common import raise_mg_import_exception
from cuml.dask.common.base import BaseEstimator

from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common.input_utils import to_output
from cuml.dask.common.input_utils import DistributedDataHandler
Expand Down Expand Up @@ -102,7 +102,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlNN(handle=handle, **kwargs)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/solvers/cd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class CD(BaseEstimator,
Expand Down Expand Up @@ -78,5 +78,5 @@ def predict(self, X, delayed=True):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.solvers.cd_mg import CDMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return CDMG(handle=handle, output_type=datatype, **kwargs)
4 changes: 2 additions & 2 deletions python/cuml/test/dask/test_random_forest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
#


# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions wiki/mnmg/Using_Infiniband_for_MNMG.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ dask-cuda-worker ucx://10.0.0.50:8786
```python
from dask.distributed import Client, wait
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common import worker_state
from cuml.dask.common import get_raft_comm_state
from cuml.dask.common import perform_test_comms_send_recv
from cuml.dask.common import perform_test_comms_allreduce

Expand All @@ -312,7 +312,7 @@ cb.init()
n_trials = 2

def func_test_send_recv(sessionId, n_trials, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_send_recv(handle, n_trials)

p2p_dfs=[c.submit(func_test_send_recv, cb.sessionId, n_trials, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down Expand Up @@ -355,7 +355,7 @@ Rank 7 received: [2, 11, 12, 9, 10, 13, 14, 8, 15, 4, 1, 6, 5, 0, 3]
### Test collective communications:
```python
def func_test_allreduce(sessionId, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_allreduce(handle)

coll_dfs = [c.submit(func_test_allreduce, cb.sessionId, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down

0 comments on commit 6b5e7ff

Please sign in to comment.