Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] API update to match RAFT PR #120 #3386

Merged
merged 5 commits into from
Jan 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/cuml/dask/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from cuml.dask.common.input_utils import DistributedDataHandler

from cuml.raft.dask.common.comms import Comms
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state

from cuml.dask.common.utils import wait_and_raise_from_futures

Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, client=None, verbose=False, **kwargs):
@mnmg_import
def _func_fit(sessionId, objs, datatype, **kwargs):
from cuml.cluster.kmeans_mg import KMeansMG as cumlKMeans
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]

inp_data = concatenate(objs)

Expand All @@ -132,7 +132,8 @@ def fit(self, X):
data = DistributedDataHandler.create(X, client=self.client)
self.datatype = data.datatype

comms = Comms(comms_p2p=False)
# This needs to happen on the scheduler
comms = Comms(comms_p2p=False, client=self.client)
comms.init(workers=data.workers)

kmeans_fit = [self.client.submit(KMeans._func_fit,
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/decomposition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#

from cuml.dask.common import raise_exception_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms

from cuml.dask.common.input_utils import to_output
Expand Down Expand Up @@ -118,5 +118,5 @@ def _fit(self, X, _transform=False):

@staticmethod
def _create_model(sessionId, model_func, datatype, **kwargs):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return model_func(handle, datatype, **kwargs)
4 changes: 2 additions & 2 deletions python/cuml/dask/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class LinearRegression(BaseEstimator,
Expand Down Expand Up @@ -113,6 +113,6 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.linear_regression_mg import LinearRegressionMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return LinearRegressionMG(handle=handle, output_type=datatype,
**kwargs)
4 changes: 2 additions & 2 deletions python/cuml/dask/linear_model/ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class Ridge(BaseEstimator,
Expand Down Expand Up @@ -124,5 +124,5 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.ridge_mg import RidgeMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return RidgeMG(handle=handle, output_type=datatype, **kwargs)
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/kneighbors_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
from dask.dataframe import Series as DaskSeries
import dask.array as da
Expand Down Expand Up @@ -124,7 +124,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/kneighbors_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
import dask.array as da
from uuid import uuid1
Expand Down Expand Up @@ -101,7 +101,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cuml.dask.common import raise_mg_import_exception
from cuml.dask.common.base import BaseEstimator

from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common.input_utils import to_output
from cuml.dask.common.input_utils import DistributedDataHandler
Expand Down Expand Up @@ -102,7 +102,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlNN(handle=handle, **kwargs)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/solvers/cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class CD(BaseEstimator,
Expand Down Expand Up @@ -78,5 +78,5 @@ def predict(self, X, delayed=True):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.solvers.cd_mg import CDMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return CDMG(handle=handle, output_type=datatype, **kwargs)
6 changes: 3 additions & 3 deletions wiki/mnmg/Using_Infiniband_for_MNMG.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ dask-cuda-worker ucx://10.0.0.50:8786
```python
from dask.distributed import Client, wait
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common import worker_state
from cuml.dask.common import get_raft_comm_state
from cuml.dask.common import perform_test_comms_send_recv
from cuml.dask.common import perform_test_comms_allreduce

Expand All @@ -312,7 +312,7 @@ cb.init()
n_trials = 2

def func_test_send_recv(sessionId, n_trials, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_send_recv(handle, n_trials)

p2p_dfs=[c.submit(func_test_send_recv, cb.sessionId, n_trials, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down Expand Up @@ -355,7 +355,7 @@ Rank 7 received: [2, 11, 12, 9, 10, 13, 14, 8, 15, 4, 1, 6, 5, 0, 3]
### Test collective communications:
```python
def func_test_allreduce(sessionId, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_allreduce(handle)

coll_dfs = [c.submit(func_test_allreduce, cb.sessionId, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down