diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 86bc064e0829..dae1a8338821 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -11,6 +11,7 @@
 from urllib.parse import urlparse
 
 import pytest
+from sklearn.metrics import accuracy_score, r2_score
 
 import lightgbm as lgb
 
@@ -75,6 +76,13 @@ def cluster2():
     dask_cluster.close()
 
 
+@pytest.fixture(scope='module')
+def cluster_three_workers():
+    dask_cluster = LocalCluster(n_workers=3, threads_per_worker=1, dashboard_address=None)
+    yield dask_cluster
+    dask_cluster.close()
+
+
 @pytest.fixture()
 def listen_port():
     listen_port.port += 10
@@ -1503,56 +1511,54 @@ def f(part):
 
 @pytest.mark.parametrize('task', tasks)
 @pytest.mark.parametrize('output', data_output)
-def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster):
-    pytest.skip("skipping due to timeout issues discussed in https://github.com/microsoft/LightGBM/pull/5510")
+def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster_three_workers):
     if task == 'ranking' and output == 'scipy_csr_matrix':
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
 
-    with Client(cluster) as client:
-        def collection_to_single_partition(collection):
-            """Merge the parts of a Dask collection into a single partition."""
-            if collection is None:
-                return
-            if isinstance(collection, da.Array):
-                return collection.rechunk(*collection.shape)
-            return collection.repartition(npartitions=1)
-
-        X, y, w, g, dX, dy, dw, dg = _create_data(
+    with Client(cluster_three_workers) as client:
+        _, y, _, _, dX, dy, dw, dg = _create_data(
             objective=task,
             output=output,
-            group=None
+            group=None,
+            n_samples=1_000,
+            chunk_size=200,
         )
 
         dask_model_factory = task_to_dask_factory[task]
-        local_model_factory = task_to_local_factory[task]
 
-        dX = collection_to_single_partition(dX)
-        dy = collection_to_single_partition(dy)
-        dw = collection_to_single_partition(dw)
-        dg = collection_to_single_partition(dg)
+        workers = list(client.scheduler_info()['workers'].keys())
+        assert len(workers) == 3
+        first_two_workers = workers[:2]
 
-        n_workers = len(client.scheduler_info()['workers'])
-        assert n_workers > 1
-        assert dX.npartitions == 1
+        dX = client.persist(dX, workers=first_two_workers)
+        dy = client.persist(dy, workers=first_two_workers)
+        dw = client.persist(dw, workers=first_two_workers)
+        wait([dX, dy, dw])
+
+        workers_with_data = set()
+        for coll in (dX, dy, dw):
+            for with_data in client.who_has(coll).values():
+                workers_with_data.update(with_data)
+                assert workers[2] not in with_data
+        assert len(workers_with_data) == 2
 
         params = {
             'time_out': 5,
             'random_state': 42,
-            'num_leaves': 10
+            'num_leaves': 10,
+            'n_estimators': 20,
         }
 
         dask_model = dask_model_factory(tree='data', client=client, **params)
         dask_model.fit(dX, dy, group=dg, sample_weight=dw)
         dask_preds = dask_model.predict(dX).compute()
-
-        local_model = local_model_factory(**params)
-        if task == 'ranking':
-            local_model.fit(X, y, group=g, sample_weight=w)
+        if task == 'regression':
+            score = r2_score(y, dask_preds)
+        elif task.endswith('classification'):
+            score = accuracy_score(y, dask_preds)
         else:
-            local_model.fit(X, y, sample_weight=w)
-        local_preds = local_model.predict(X)
-
-        assert assert_eq(dask_preds, local_preds)
+            score = spearmanr(dask_preds, y).correlation
+        assert score > 0.9
 
 
 @pytest.mark.parametrize('task', tasks)