diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py index 4fe7504ce4..a512d78d4f 100644 --- a/python/cuml/tests/dask/test_dask_logistic_regression.py +++ b/python/cuml/tests/dask/test_dask_logistic_regression.py @@ -303,14 +303,7 @@ def assert_params( ) -@pytest.mark.mg -@pytest.mark.parametrize("nrows", [1e5]) -@pytest.mark.parametrize("ncols", [20]) -@pytest.mark.parametrize("n_parts", [2, 23]) -@pytest.mark.parametrize("fit_intercept", [False, True]) -@pytest.mark.parametrize("datatype", [np.float32]) -@pytest.mark.parametrize("delayed", [True, False]) -def test_lbfgs( +def _test_lbfgs( nrows, ncols, n_parts, @@ -428,9 +421,25 @@ def array_to_numpy(ary): return lr +@pytest.mark.mg +@pytest.mark.parametrize("n_parts", [2, 23]) +@pytest.mark.parametrize("fit_intercept", [False, True]) +@pytest.mark.parametrize("delayed", [True, False]) +def test_lbfgs(n_parts, fit_intercept, delayed, client): + _test_lbfgs( + nrows=1e5, + ncols=20, + n_parts=n_parts, + fit_intercept=fit_intercept, + datatype=np.float32, + delayed=delayed, + client=client, + ) + + @pytest.mark.parametrize("fit_intercept", [False, True]) def test_noreg(fit_intercept, client): - lr = test_lbfgs( + lr = _test_lbfgs( nrows=1e5, ncols=20, n_parts=23, @@ -494,7 +503,7 @@ def assert_small(X, y, n_classes): @pytest.mark.parametrize("n_classes", [8]) def test_n_classes(n_parts, fit_intercept, n_classes, client): nrows = int(1e5) if n_classes < 5 else int(2e5) - lr = test_lbfgs( + lr = _test_lbfgs( nrows=nrows, ncols=20, n_parts=n_parts, @@ -517,7 +526,7 @@ def test_n_classes(n_parts, fit_intercept, n_classes, client): @pytest.mark.parametrize("C", [1.0, 10.0]) def test_l1(fit_intercept, datatype, delayed, n_classes, C, client): nrows = int(1e5) if n_classes < 5 else int(2e5) - lr = test_lbfgs( + lr = _test_lbfgs( nrows=nrows, ncols=20, n_parts=2, @@ -545,7 +554,7 @@ def test_elasticnet( fit_intercept, datatype, delayed, n_classes, l1_ratio, client ): nrows = int(1e5) if n_classes < 5 else int(2e5) - lr = test_lbfgs( + lr = _test_lbfgs( nrows=nrows, ncols=20, n_parts=2, @@ -585,7 +594,7 @@ def test_sparse_from_dense( nrows = int(1e5) if n_classes < 5 else int(2e5) run_test = partial( - test_lbfgs, + _test_lbfgs, nrows=nrows, ncols=20, n_parts=2, @@ -699,7 +708,7 @@ def test_standardization_on_normal_dataset( nrows = int(1e5) if n_classes < 5 else int(2e5) # test correctness compared with scikit-learn - test_lbfgs( + _test_lbfgs( nrows=nrows, ncols=20, n_parts=2,