Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] API update to match RAFT PR #120 #3386

Merged
merged 5 commits into from
Jan 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG 9161d7a238aca859453d8517bd7ad92cbd902f6a
GIT_TAG 9dbf2c8a9134ce8135f7fe947ec523d874fcab6a
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
9 changes: 5 additions & 4 deletions python/cuml/dask/cluster/kmeans.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,7 +24,7 @@
from cuml.dask.common.input_utils import DistributedDataHandler

from cuml.raft.dask.common.comms import Comms
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state

from cuml.dask.common.utils import wait_and_raise_from_futures

Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, client=None, verbose=False, **kwargs):
@mnmg_import
def _func_fit(sessionId, objs, datatype, **kwargs):
from cuml.cluster.kmeans_mg import KMeansMG as cumlKMeans
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]

inp_data = concatenate(objs)

Expand All @@ -132,7 +132,8 @@ def fit(self, X):
data = DistributedDataHandler.create(X, client=self.client)
self.datatype = data.datatype

comms = Comms(comms_p2p=False)
# This needs to happen on the scheduler
comms = Comms(comms_p2p=False, client=self.client)
comms.init(workers=data.workers)

kmeans_fit = [self.client.submit(KMeans._func_fit,
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/decomposition/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@
#

from cuml.dask.common import raise_exception_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms

from cuml.dask.common.input_utils import to_output
Expand Down Expand Up @@ -118,5 +118,5 @@ def _fit(self, X, _transform=False):

@staticmethod
def _create_model(sessionId, model_func, datatype, **kwargs):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return model_func(handle, datatype, **kwargs)
2 changes: 1 addition & 1 deletion python/cuml/dask/ensemble/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/dask/ensemble/randomforestclassifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/dask/ensemble/randomforestregressor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class LinearRegression(BaseEstimator,
Expand Down Expand Up @@ -113,6 +113,6 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.linear_regression_mg import LinearRegressionMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return LinearRegressionMG(handle=handle, output_type=datatype,
**kwargs)
6 changes: 3 additions & 3 deletions python/cuml/dask/linear_model/ridge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class Ridge(BaseEstimator,
Expand Down Expand Up @@ -124,5 +124,5 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.ridge_mg import RidgeMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return RidgeMG(handle=handle, output_type=datatype, **kwargs)
6 changes: 3 additions & 3 deletions python/cuml/dask/neighbors/kneighbors_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
from dask.dataframe import Series as DaskSeries
import dask.array as da
Expand Down Expand Up @@ -124,7 +124,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/neighbors/kneighbors_regressor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
import dask.array as da
from uuid import uuid1
Expand Down Expand Up @@ -101,7 +101,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/neighbors/nearest_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@
from cuml.dask.common import raise_mg_import_exception
from cuml.dask.common.base import BaseEstimator

from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common.input_utils import to_output
from cuml.dask.common.input_utils import DistributedDataHandler
Expand Down Expand Up @@ -102,7 +102,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlNN(handle=handle, **kwargs)

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/dask/solvers/cd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class CD(BaseEstimator,
Expand Down Expand Up @@ -78,5 +78,5 @@ def predict(self, X, delayed=True):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.solvers.cd_mg import CDMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return CDMG(handle=handle, output_type=datatype, **kwargs)
4 changes: 2 additions & 2 deletions python/cuml/test/dask/test_random_forest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
#


# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions wiki/mnmg/Using_Infiniband_for_MNMG.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ dask-cuda-worker ucx://10.0.0.50:8786
```python
from dask.distributed import Client, wait
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common import worker_state
from cuml.dask.common import get_raft_comm_state
from cuml.dask.common import perform_test_comms_send_recv
from cuml.dask.common import perform_test_comms_allreduce

Expand All @@ -312,7 +312,7 @@ cb.init()
n_trials = 2

def func_test_send_recv(sessionId, n_trials, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_send_recv(handle, n_trials)

p2p_dfs=[c.submit(func_test_send_recv, cb.sessionId, n_trials, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down Expand Up @@ -355,7 +355,7 @@ Rank 7 received: [2, 11, 12, 9, 10, 13, 14, 8, 15, 4, 1, 6, 5, 0, 3]
### Test collective communications:
```python
def func_test_allreduce(sessionId, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_allreduce(handle)

coll_dfs = [c.submit(func_test_allreduce, cb.sessionId, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down