Skip to content

Commit

Permalink
API update to match RAFT PR rapidsai#120
Browse files Browse the repository at this point in the history
  • Loading branch information
drobison00 committed Jan 19, 2021
1 parent d72c54a commit 5139cf3
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 20 deletions.
7 changes: 4 additions & 3 deletions python/cuml/dask/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from cuml.dask.common.input_utils import DistributedDataHandler

from cuml.raft.dask.common.comms import Comms
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state

from cuml.dask.common.utils import wait_and_raise_from_futures

Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, client=None, verbose=False, **kwargs):
@mnmg_import
def _func_fit(sessionId, objs, datatype, **kwargs):
from cuml.cluster.kmeans_mg import KMeansMG as cumlKMeans
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]

inp_data = concatenate(objs)

Expand All @@ -132,7 +132,8 @@ def fit(self, X):
data = DistributedDataHandler.create(X, client=self.client)
self.datatype = data.datatype

comms = Comms(comms_p2p=False)
# This needs to happen on the scheduler
comms = Comms(comms_p2p=False, client=self.client)
comms.init(workers=data.workers)

kmeans_fit = [self.client.submit(KMeans._func_fit,
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/decomposition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#

from cuml.dask.common import raise_exception_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms

from cuml.dask.common.input_utils import to_output
Expand Down Expand Up @@ -118,5 +118,5 @@ def _fit(self, X, _transform=False):

@staticmethod
def _create_model(sessionId, model_func, datatype, **kwargs):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return model_func(handle, datatype, **kwargs)
4 changes: 2 additions & 2 deletions python/cuml/dask/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class LinearRegression(BaseEstimator,
Expand Down Expand Up @@ -113,6 +113,6 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.linear_regression_mg import LinearRegressionMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return LinearRegressionMG(handle=handle, output_type=datatype,
**kwargs)
4 changes: 2 additions & 2 deletions python/cuml/dask/linear_model/ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class Ridge(BaseEstimator,
Expand Down Expand Up @@ -124,5 +124,5 @@ def get_param_names(self):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.linear_model.ridge_mg import RidgeMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return RidgeMG(handle=handle, output_type=datatype, **kwargs)
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/kneighbors_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
from dask.dataframe import Series as DaskSeries
import dask.array as da
Expand Down Expand Up @@ -124,7 +124,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/kneighbors_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cuml.dask.common import flatten_grouped_results
from cuml.dask.common.utils import raise_mg_import_exception
from cuml.dask.common.utils import wait_and_raise_from_futures
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.dask.neighbors import NearestNeighbors
import dask.array as da
from uuid import uuid1
Expand Down Expand Up @@ -101,7 +101,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlKNN(handle=handle, **kwargs)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/neighbors/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cuml.dask.common import raise_mg_import_exception
from cuml.dask.common.base import BaseEstimator

from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common.input_utils import to_output
from cuml.dask.common.input_utils import DistributedDataHandler
Expand Down Expand Up @@ -102,7 +102,7 @@ def _func_create_model(sessionId, **kwargs):
except ImportError:
raise_mg_import_exception()

handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return cumlNN(handle=handle, **kwargs)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/solvers/cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cuml.dask.common.base import DelayedPredictionMixin
from cuml.dask.common.base import mnmg_import
from cuml.dask.common.base import SyncFitMixinLinearModel
from cuml.raft.dask.common.comms import worker_state
from cuml.raft.dask.common.comms import get_raft_comm_state


class CD(BaseEstimator,
Expand Down Expand Up @@ -78,5 +78,5 @@ def predict(self, X, delayed=True):
@mnmg_import
def _create_model(sessionId, datatype, **kwargs):
from cuml.solvers.cd_mg import CDMG
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return CDMG(handle=handle, output_type=datatype, **kwargs)
6 changes: 3 additions & 3 deletions wiki/mnmg/Using_Infiniband_for_MNMG.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ dask-cuda-worker ucx://10.0.0.50:8786
```python
from dask.distributed import Client, wait
from cuml.raft.dask.common.comms import Comms
from cuml.dask.common import worker_state
from cuml.dask.common import get_raft_comm_state
from cuml.dask.common import perform_test_comms_send_recv
from cuml.dask.common import perform_test_comms_allreduce

Expand All @@ -312,7 +312,7 @@ cb.init()
n_trials = 2

def func_test_send_recv(sessionId, n_trials, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_send_recv(handle, n_trials)

p2p_dfs=[c.submit(func_test_send_recv, cb.sessionId, n_trials, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down Expand Up @@ -355,7 +355,7 @@ Rank 7 received: [2, 11, 12, 9, 10, 13, 14, 8, 15, 4, 1, 6, 5, 0, 3]
### Test collective communications:
```python
def func_test_allreduce(sessionId, r):
handle = worker_state(sessionId)["handle"]
handle = get_raft_comm_state(sessionId)["handle"]
return perform_test_comms_allreduce(handle)

coll_dfs = [c.submit(func_test_allreduce, cb.sessionId, random.random(), workers=[w]) for wid, w in zip(range(len(cb.worker_addresses)), cb.worker_addresses)]
Expand Down

0 comments on commit 5139cf3

Please sign in to comment.