[dask] Order the prediction result. #5416

Merged · 4 commits · Mar 15, 2020
99 changes: 67 additions & 32 deletions python-package/xgboost/dask.py
@@ -139,9 +139,6 @@ class DaskDMatrix:

'''

_feature_names = None # for previous version's pickle
_feature_types = None

def __init__(self,
client,
data,
@@ -153,9 +150,9 @@ def __init__(self,
_assert_dask_support()
_assert_client(client)

self._feature_names = feature_names
self._feature_types = feature_types
self._missing = missing
self.feature_names = feature_names
self.feature_types = feature_types
self.missing = missing

if len(data.shape) != 2:
raise ValueError(
@@ -237,6 +234,10 @@ def check_columns(parts):
for part in parts:
assert part.status == 'finished'

self.partition_order = {}
for i, part in enumerate(parts):
self.partition_order[part.key] = i

key_to_partition = {part.key: part for part in parts}
who_has = await client.scheduler.who_has(
keys=[part.key for part in parts])
@@ -247,6 +248,16 @@ def check_columns(parts):

self.worker_map = worker_map

def get_worker_x_ordered(self, worker):
list_of_parts = self.worker_map[worker.address]
client = get_client()
list_of_parts_value = client.gather(list_of_parts)
result = []
for i, part in enumerate(list_of_parts):
result.append((list_of_parts_value[i][0],
self.partition_order[part.key]))
return result

def get_worker_parts(self, worker):
'''Get mapped parts of data in each worker.'''
list_of_parts = self.worker_map[worker.address]
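
The two additions above (the partition_order mapping and get_worker_x_ordered) exist so that per-worker results can later be put back into the order of the input collection, since dask gives no guarantee about which worker ends up holding which partitions. A minimal standalone sketch of the same bookkeeping, with simplified names that are not part of the real DaskDMatrix API:

    from distributed import get_client

    def build_partition_order(parts):
        # Remember where each partition future sat in the original dask collection.
        return {part.key: i for i, part in enumerate(parts)}

    def worker_parts_in_order(worker_map, partition_order, worker_address):
        # Gather this worker's partitions together with their original indices,
        # so predictions computed from them can later be re-sorted into input order.
        client = get_client()
        parts = worker_map[worker_address]
        values = client.gather(parts)
        return [(value, partition_order[part.key])
                for value, part in zip(values, parts)]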
@@ -289,8 +300,8 @@ def get_worker_data(self, worker):
workers=set(self.worker_map.keys()))
logging.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=self._feature_names,
feature_types=self._feature_types)
feature_names=self.feature_names,
feature_types=self.feature_types)
return d

data, labels, weights = self.get_worker_parts(worker)
@@ -308,9 +319,9 @@ def get_worker_data(self, worker):
dmatrix = DMatrix(data,
labels,
weight=weights,
missing=self._missing,
feature_names=self._feature_names,
feature_types=self._feature_types)
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types)
return dmatrix

def get_worker_data_shape(self, worker):
@@ -457,41 +468,65 @@ def predict(client, model, data, *args):
worker_map = data.worker_map
client = _xgb_get_client(client)

rabit_args = _get_rabit_args(worker_map, client)
missing = data.missing
feature_names = data.feature_names
feature_types = data.feature_types

def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
logging.info('Predicting on %d', worker_id)
worker = distributed_get_worker()
local_x = data.get_worker_data(worker)

with RabitContext(rabit_args):
local_predictions = booster.predict(
data=local_x, validate_features=local_x.num_row() != 0, *args)
return local_predictions

futures = client.map(dispatched_predict,
range(len(worker_map)),
pure=False,
workers=list(worker_map.keys()))
list_of_parts = data.get_worker_x_ordered(worker)
predictions = []
for part, order in list_of_parts:
local_x = DMatrix(part,
feature_names=feature_names,
feature_types=feature_types,
missing=missing)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
ret = (delayed(predt), order)
predictions.append(ret)
return predictions

def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
logging.info('Trying to get data shape on %d', worker_id)
worker = distributed_get_worker()
rows, _ = data.get_worker_data_shape(worker)
return rows, 1 # default is 1
list_of_parts = data.get_worker_x_ordered(worker)
shapes = []
for part, order in list_of_parts:
s = part.shape
shapes.append((s, order))
return shapes

def map_function(func):
'''Run function for each part of the data.'''
futures = []
for wid in range(len(worker_map)):
list_of_workers = [list(worker_map.keys())[wid]]
f = client.submit(func, wid,
pure=False,
workers=list_of_workers)
futures.append(f)

# Get delayed objects
results = client.gather(futures)
results = [t for l in results for t in l] # flatten into 1 dim list
# sort by order, l[0] is the delayed object, l[1] is its order
results = sorted(results, key=lambda l: l[1])
results = [predt for predt, order in results] # remove order
return results

results = map_function(dispatched_predict)
shapes = map_function(dispatched_get_shape)

# Constructing a dask array from list of numpy arrays
# See https://docs.dask.org/en/latest/array-creation.html
futures_shape = client.map(dispatched_get_shape,
range(len(worker_map)),
pure=False,
workers=list(worker_map.keys()))
shapes = client.gather(futures_shape)
arrays = []
for i in range(len(futures_shape)):
arrays.append(da.from_delayed(futures[i], shape=(shapes[i][0], ),
for i, shape in enumerate(shapes):
arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
dtype=numpy.float32))
predictions = da.concatenate(arrays, axis=0)
return predictions
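
The reassembly step at the end of predict can be illustrated in isolation: each per-partition prediction becomes a delayed object with a known shape, is wrapped as a dask array, and the arrays are concatenated in partition order. A small self-contained sketch with made-up data (only the da.from_delayed / da.concatenate calls mirror the diff):

    import numpy as np
    import dask
    import dask.array as da

    # Pretend these are per-partition predictions, each tagged with the
    # partition's original position in the input collection.
    tagged = [(np.array([0.1, 0.2], dtype=np.float32), 1),
              (np.array([0.9], dtype=np.float32), 0)]

    # Sort by original partition index so concatenation restores input order.
    tagged.sort(key=lambda t: t[1])

    arrays = [da.from_delayed(dask.delayed(part), shape=part.shape, dtype=np.float32)
              for part, _ in tagged]
    predictions = da.concatenate(arrays, axis=0)
    print(predictions.compute())  # [0.9 0.1 0.2]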
5 changes: 5 additions & 0 deletions tests/python/test_with_dask.py
@@ -74,6 +74,11 @@ def test_from_dask_array():
# force prediction to be computed
prediction = prediction.compute()

single_node_predt = result['booster'].predict(
xgb.DMatrix(X.compute())
)
np.testing.assert_allclose(prediction, single_node_predt)


def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster:
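
For context, a rough end-to-end sketch of the workflow the new assertion in test_from_dask_array exercises: train on a DaskDMatrix, predict in a distributed fashion, and check the result against a single-node prediction. Cluster size, parameters, and data are made up, and the xgboost.dask call signatures are taken as of this PR, so treat them as illustrative rather than authoritative:

    import numpy as np
    import xgboost as xgb
    from dask import array as da
    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=100)
        dtrain = xgb.dask.DaskDMatrix(client, X, y)

        result = xgb.dask.train(client, {'objective': 'reg:squarederror'},
                                dtrain, num_boost_round=4)

        # Distributed prediction now comes back in the same row order as X ...
        distributed = xgb.dask.predict(client, result['booster'], dtrain).compute()
        # ... so it should match a single-node prediction on the gathered data.
        single_node = result['booster'].predict(xgb.DMatrix(X.compute()))
        np.testing.assert_allclose(distributed, single_node)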