Skip to content

Commit

Permalink
Fixture.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 10, 2020
1 parent 04f87fd commit 68e5169
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

try:
from distributed import LocalCluster, Client
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
import dask.array as da
from xgboost.dask import DaskDMatrix
except ImportError:
LocalCluster = None
Client = None
client = None
loop = None
cluster_fixture = None
dd = None
da = None
DaskDMatrix = None
Expand Down Expand Up @@ -452,7 +456,7 @@ def test_with_asyncio():
asyncio.run(run_dask_classifier_asyncio(address))


class TestWithDask(unittest.TestCase):
class TestWithDask:
def run_updater_test(self, client, params, num_rounds, dataset,
tree_method):
params['tree_method'] = tree_method
Expand Down Expand Up @@ -483,31 +487,29 @@ def run_updater_test(self, client, params, num_rounds, dataset,
note(history)
assert tm.non_increasing(history['train'][dataset.metric])

@given(hist_parameter_strategy, strategies.integers(10, 20),
tm.dataset_strategy)
@given(params=hist_parameter_strategy,
num_rounds=strategies.integers(10, 20),
dataset=tm.dataset_strategy)
@settings(deadline=None)
def test_hist(self, params, num_rounds, dataset):
with LocalCluster() as cluster:
with Client(cluster) as client:
self.run_updater_test(
client, params, num_rounds, dataset, 'hist')
def test_hist(self, params, num_rounds, dataset, client):
self.run_updater_test(client, params, num_rounds, dataset, 'hist')

@given(exact_parameter_strategy, strategies.integers(10, 20),
tm.dataset_strategy)
@given(params=exact_parameter_strategy,
num_rounds=strategies.integers(10, 20),
dataset=tm.dataset_strategy)
@settings(deadline=None)
def test_approx(self, params, num_rounds, dataset):
with LocalCluster() as cluster:
with Client(cluster) as client:
self.run_updater_test(
client, params, num_rounds, dataset, 'approx')
def test_approx(self, client, params, num_rounds, dataset):
self.run_updater_test(client, params, num_rounds, dataset, 'approx')

def run_quantile(self, name):
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")

exe = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/testxgboost', '../cpu-build/testxgboost'}:
'../build/testxgboost',
'../cpu-build/testxgboost',
'../gpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
if exe is None:
Expand All @@ -529,7 +531,8 @@ def runit(worker_addr, rabit_args):
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
workers = list(xgb.dask._get_client_workers(client).keys())
rabit_args = xgb.dask._get_rabit_args(workers, client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, workers, client)
futures = client.map(runit,
workers,
pure=False,
Expand Down

0 comments on commit 68e5169

Please sign in to comment.