Skip to content

Commit

Permalink
Add rabit ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 1, 2020
1 parent d19cec7 commit 709ba29
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
12 changes: 12 additions & 0 deletions python-package/xgboost/rabit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def get_world_size():
return ret


def is_distributed():
is_dist = _LIB.RabitIsDistributed()
return is_dist


def tracker_print(msg):
"""Print message to the tracker.
Expand Down Expand Up @@ -143,6 +148,13 @@ def broadcast(data, root):
}


class Op:
MAX = 0
MIN = 1
SUM = 2
OR = 3


def allreduce(data, op, prepare_fun=None):
"""Perform allreduce, return the result.
Expand Down
32 changes: 30 additions & 2 deletions tests/python/test_tracker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

from xgboost import RabitTracker
import xgboost as xgb
import pytest
import testing as tm


def test_rabit_tracker():
Expand All @@ -15,3 +15,31 @@ def test_rabit_tracker():
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
xgb.rabit.finalize()


def run_rabit_ops(client, n_workers):
from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
from xgboost import rabit

workers = _get_client_workers(client)
rabit_args = _get_rabit_args(workers, client)
assert not rabit.is_distributed()

def local_test(worker_id):
with RabitContext(rabit_args):
a = 1
assert rabit.is_distributed()
reduced = rabit.allreduce(a, rabit.Op.SUM)
assert reduced[0] == n_workers

reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
assert reduced == n_workers - 1


@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
from distributed import Client, LocalCluster
n_workers = 3
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)

0 comments on commit 709ba29

Please sign in to comment.