From 96bebadc533104bfb8f3abf11d6bf0edf151c171 Mon Sep 17 00:00:00 2001 From: Carl Simon Adorf Date: Fri, 11 Nov 2022 21:10:17 +0100 Subject: [PATCH 1/2] Implement hypothesis-based tests for linear models (#4952) Closes #4943. Authors: - Carl Simon Adorf (https://github.com/csadorf) Approvers: - William Hicks (https://github.com/wphicks) URL: https://github.com/rapidsai/cuml/pull/4952 --- python/cuml/testing/strategies.py | 250 +++++++++++++++++++++++++ python/cuml/tests/conftest.py | 36 ++++ python/cuml/tests/test_linear_model.py | 93 ++++++++- python/cuml/tests/test_strategies.py | 129 +++++++++++++ wiki/python/DEVELOPER_GUIDE.md | 2 + 5 files changed, 507 insertions(+), 3 deletions(-) create mode 100644 python/cuml/testing/strategies.py create mode 100644 python/cuml/tests/test_strategies.py diff --git a/python/cuml/testing/strategies.py b/python/cuml/testing/strategies.py new file mode 100644 index 0000000000..8d4c1cc335 --- /dev/null +++ b/python/cuml/testing/strategies.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from hypothesis import assume +from hypothesis.extra.numpy import arrays, floating_dtypes +from hypothesis.strategies import composite, integers, just, none, one_of +from sklearn.datasets import make_regression +from sklearn.model_selection import train_test_split + + +@composite +def standard_datasets( + draw, + dtypes=floating_dtypes(), + n_samples=integers(min_value=0, max_value=200), + n_features=integers(min_value=0, max_value=200), + *, + n_targets=just(1), +): + """ + Returns a strategy to generate standard estimator input datasets. + + Parameters + ---------- + dtypes: SearchStrategy[np.dtype], default=floating_dtypes() + Returned arrays will have a dtype drawn from these types. + n_samples: SearchStrategy[int], \ + default=integers(min_value=0, max_value=200) + Returned arrays will have number of rows drawn from these values. + n_features: SearchStrategy[int], \ + default=integers(min_value=0, max_values=200) + Returned arrays will have number of columns drawn from these values. + n_targets: SearchStrategy[int], default=just(1) + Determines the number of targets returned datasets may contain. + + Returns + ------- + X: SearchStrategy[array] (n_samples, n_features) + The search strategy for input samples. + y: SearchStrategy[array] (n_samples,) or (n_samples, n_targets) + The search strategy for output samples. + + """ + xs = draw(n_samples) + ys = draw(n_features) + X = arrays(dtype=dtypes, shape=(xs, ys)) + y = arrays(dtype=dtypes, shape=(xs, draw(n_targets))) + return draw(X), draw(y) + + +def combined_datasets_strategy(* datasets, name=None, doc=None): + """ + Combine multiple datasets strategies into a single datasets strategy. + + This function will return a new strategy that will build the provided + strategy functions with the common parameters (dtypes, n_samples, + n_features) and then draw from one of them. + + Parameters: + ----------- + * datasets: list[Callable[[dtypes, n_samples, n_features], SearchStrategy]] + A list of functions that return a dataset search strategy when called + with the shown arguments. + name: The name of the returned search strategy, default="datasets" + Defaults to a combination of names of the provided dataset strategy + functions. + doc: The doc-string of the returned search strategy, default=None + Defaults to a generic doc-string. + + Returns + ------- + Datasets search strategy: SearchStrategy[array], SearchStrategy[array] + """ + + @composite + def strategy( + draw, + dtypes=floating_dtypes(), + n_samples=integers(min_value=0, max_value=200), + n_features=integers(min_value=0, max_value=200) + ): + """Datasets strategy composed of multiple datasets strategies.""" + datasets_strategies = ( + dataset(dtypes, n_samples, n_features) for dataset in datasets) + return draw(one_of(datasets_strategies)) + + strategy.__name__ = "datasets" if name is None else name + if doc is not None: + strategy.__doc__ = doc + + return strategy + + +@composite +def split_datasets( + draw, + datasets, + test_sizes=None, +): + """ + Split a generic search strategy for datasets into test and train subsets. + + The resulting split is guaranteed to have at least one sample in both the + train and test split respectively. + + Note: This function uses the sklearn.model_selection.train_test_split + function. + + See also: + standard_datasets(): A search strategy for datasets that can serve as input + to this strategy. + + Parameters + ---------- + datasets: SearchStrategy[dataset] + A search strategy for datasets. + test_sizes: SearchStrategy[int | float], default=None + A search strategy for the test size. Must be provided as a search + strategy for integers or floats. Integers should be bound by one and + the sample size, floats should be between 0 and 1.0. Defaults to + a search strategy that will generate a valid unbiased split. + + Returns + ------- + (X_train, X_test, y_train, y_test): tuple[SearchStrategy[array], ...] + The train-test split of the input and output samples drawn from + the provided datasets search strategy. + """ + X, y = draw(datasets) + assume(len(X) > 1) + + # Determine default value for test_sizes + if test_sizes is None: + test_sizes = integers(1, max(1, len(X) - 1)) + + test_size = draw(test_sizes) + + # Check assumptions for test_size + if isinstance(test_size, float): + assume(int(len(X) * test_size) > 0) + assume(int(len(X) * (1.0 - test_size)) > 0) + elif isinstance(test_size, int): + assume(1 < test_size < len(X)) + + return train_test_split(X, y, test_size=test_size) + + +@composite +def standard_regression_datasets( + draw, + dtypes=floating_dtypes(), + n_samples=integers(min_value=100, max_value=200), + n_features=integers(min_value=100, max_value=200), + *, + n_informative=None, + n_targets=just(1), + bias=just(0.0), + effective_rank=none(), + tail_strength=just(0.5), + noise=just(0.0), + shuffle=just(True), + random_state=None, +): + """ + Returns a strategy to generate regression problem input datasets. + + Note: + This function uses the sklearn.datasets.make_regression function to + generate the regression problem from the provided search strategies. + + + Parameters + ---------- + dtypes: SearchStrategy[np.dtype] + Returned arrays will have a dtype drawn from these types. + n_samples: SearchStrategy[int] + Returned arrays will have number of rows drawn from these values. + n_features: SearchStrategy[int] + Returned arrays will have number of columns drawn from these values. + n_informative: SearchStrategy[int], default=none + A search strategy for the number of informative features. If none, + will use 10% of the actual number of features, but not less than 1 + unless the number of features is zero. + n_targets: SearchStrategy[int], default=just(1) + A search strategy for the number of targets, that means the number of + columns of the returned y output array. + bias: SearchStrategy[float], default=just(0.0) + A search strategy for the bias term. + effective_rank=none() + If not None, a search strategy for the effective rank of the input data + for the regression problem. See sklearn.dataset.make_regression() for a + detailed explanation of this parameter. + tail_strength: SearchStrategy[float], default=just(0.5) + See sklearn.dataset.make_regression() for a detailed explanation of + this parameter. + noise: SearchStrategy[float], default=just(0.0) + A search strategy for the standard deviation of the gaussian noise. + shuffle: SearchStrategy[bool], default=just(True) + A boolean search strategy to determine whether samples and features + are shuffled. + random_state: int, RandomState instance or None, default=None + Pass a random state or integer to determine the random number + generation for data set generation. + + Returns + ------- + (X, y): SearchStrategy[array], SearchStrategy[array] + A tuple of search strategies for arrays subject to the constraints of + the provided parameters. + """ + n_features_ = draw(n_features) + if n_informative is None: + n_informative = just(max(min(n_features_, 1), int(0.1 * n_features_))) + X, y = make_regression( + n_samples=draw(n_samples), + n_features=n_features_, + n_informative=draw(n_informative), + n_targets=draw(n_targets), + bias=draw(bias), + effective_rank=draw(effective_rank), + tail_strength=draw(tail_strength), + noise=draw(noise), + shuffle=draw(shuffle), + random_state=random_state, + ) + dtype_ = draw(dtypes) + return X.astype(dtype_), y.astype(dtype_) + + +regression_datasets = combined_datasets_strategy( + standard_datasets, standard_regression_datasets, + name="regression_datasets", + doc=""" + Returns strategy for the generation of regression problem datasets. + + Drawn from the standard_datasets and the standard_regression_datasets + strategies. + """ +) diff --git a/python/cuml/tests/conftest.py b/python/cuml/tests/conftest.py index 414ceddfb3..0515c952fa 100644 --- a/python/cuml/tests/conftest.py +++ b/python/cuml/tests/conftest.py @@ -20,6 +20,7 @@ import numpy as np import cupy as cp +import hypothesis from math import ceil from sklearn.datasets import fetch_20newsgroups @@ -34,6 +35,30 @@ pytest_plugins = ("cuml.testing.plugins.quick_run_plugin") +# Configure hypothesis profiles + +hypothesis.settings.register_profile( + name="unit", + parent=hypothesis.settings.get_profile("default"), + max_examples=20, + suppress_health_check=[ + hypothesis.HealthCheck.data_too_large, + ], +) + +hypothesis.settings.register_profile( + name="quality", + parent=hypothesis.settings.get_profile("unit"), + max_examples=100, +) + +hypothesis.settings.register_profile( + name="stress", + parent=hypothesis.settings.get_profile("quality"), + max_examples=200 +) + + def pytest_addoption(parser): # Any custom option, that should be available at any time (not just a # plugin), goes here. @@ -101,6 +126,17 @@ def pytest_configure(config): pytest.max_gpu_memory = get_gpu_memory() pytest.adapt_stress_test = 'CUML_ADAPT_STRESS_TESTS' in os.environ + # Load special hypothesis profiles for either quality or stress tests. + # Note that the profile can be manually overwritten with the + # --hypothesis-profile command line option in which case the settings + # specified here will be ignored. + if config.getoption("--run_stress"): + hypothesis.settings.load_profile("stress") + elif config.getoption("--run_quality"): + hypothesis.settings.load_profile("quality") + else: + hypothesis.settings.load_profile("unit") + @pytest.fixture(scope="module") def nlp_20news(): diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 81a542ef0f..03a21dfada 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -16,13 +16,29 @@ import cupy as cp import numpy as np import pytest +from hypothesis import ( + assume, + example, + given, + settings, + strategies as st, + target +) +from hypothesis.extra.numpy import floating_dtypes from distutils.version import LooseVersion import cudf from cuml import ElasticNet as cuElasticNet from cuml import LinearRegression as cuLinearRegression from cuml import LogisticRegression as cuLog from cuml import Ridge as cuRidge +from cuml.common.input_utils import _typecast_will_lose_information +from cuml.testing.strategies import ( + regression_datasets, + split_datasets, + standard_regression_datasets, +) from cuml.testing.utils import ( + array_difference, array_equal, small_regression_dataset, small_classification_dataset, @@ -88,6 +104,26 @@ def make_classification_dataset(datatype, nrows, ncols, n_info, num_classes): return X_train, X_test, y_train, y_test +def sklearn_compatible_dataset(X_train, X_test, y_train, _=None): + return ( + X_train.shape[1] >= 1 + and (X_train > 0).any() + and (y_train > 0).any() + and all(np.isfinite(x).all() + for x in (X_train, X_test, y_train) if x is not None) + ) + + +def cuml_compatible_dataset(X_train, X_test, y_train, _=None): + return ( + X_train.shape[0] >= 2 + and X_train.shape[1] >= 1 + and np.isfinite(X_train).all() + and not any(_typecast_will_lose_information(x, np.float32) + for x in (X_train, X_test, y_train) if x is not None) + ) + + @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd"]) @pytest.mark.parametrize( @@ -193,10 +229,60 @@ def test_linear_regression_single_column(): model.fit(cp.random.rand(46341), cp.random.rand(46341)) -@pytest.mark.parametrize("datatype", [np.float32, np.float64]) -def test_linear_regression_model_default(datatype): +# The assumptions required to have this test pass are relatively strong. +# It should be possible to relax assumptions once #4963 is resolved. +# See also: test_linear_regression_model_default_generalized +@given( + split_datasets( + standard_regression_datasets( + dtypes=floating_dtypes(sizes=(32, 64)), + n_samples=st.just(1000), + ), + test_sizes=st.just(0.2) + ) +) +@example(small_regression_dataset(np.float32)) +@example(small_regression_dataset(np.float64)) +@settings(deadline=5000) +def test_linear_regression_model_default(dataset): - X_train, X_test, y_train, y_test = small_regression_dataset(datatype) + X_train, X_test, y_train, _ = dataset + + # Filter datasets based on required assumptions + assume(sklearn_compatible_dataset(X_train, X_test, y_train)) + assume(cuml_compatible_dataset(X_train, X_test, y_train)) + + # Initialization of cuML's linear regression model + cuols = cuLinearRegression() + + # fit and predict cuml linear regression model + cuols.fit(X_train, y_train) + cuols_predict = cuols.predict(X_test) + + # sklearn linear regression model initialization and fit + skols = skLinearRegression() + skols.fit(X_train, y_train) + + skols_predict = skols.predict(X_test) + + target(float(array_difference(skols_predict, cuols_predict))) + assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True) + + +# TODO: Replace test_linear_regression_model_default with this test once #4963 +# is resolved. +@pytest.mark.xfail(reason="https://github.com/rapidsai/cuml/issues/4963") +@given( + split_datasets(regression_datasets(dtypes=floating_dtypes(sizes=(32, 64)))) +) +@settings(deadline=5000) +def test_linear_regression_model_default_generalized(dataset): + + X_train, X_test, y_train, _ = dataset + + # Filter datasets based on required assumptions + assume(sklearn_compatible_dataset(X_train, X_test, y_train)) + assume(cuml_compatible_dataset(X_train, X_test, y_train)) # Initialization of cuML's linear regression model cuols = cuLinearRegression() @@ -211,6 +297,7 @@ def test_linear_regression_model_default(datatype): skols_predict = skols.predict(X_test) + target(float(array_difference(skols_predict, cuols_predict))) assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True) diff --git a/python/cuml/tests/test_strategies.py b/python/cuml/tests/test_strategies.py new file mode 100644 index 0000000000..ed38a202b1 --- /dev/null +++ b/python/cuml/tests/test_strategies.py @@ -0,0 +1,129 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from cuml.testing.strategies import ( + standard_regression_datasets, + regression_datasets, + split_datasets, + standard_datasets, +) +from hypothesis import given, settings, HealthCheck +from hypothesis import strategies as st +from hypothesis.extra.numpy import floating_dtypes + + +@given(standard_datasets()) +def test_standard_datasets_default(dataset): + X, y = dataset + + assert X.ndim == 2 + assert X.shape[0] <= 200 + assert X.shape[1] <= 200 + assert (y.ndim == 0) or (y.ndim in (1, 2) and y.shape[0] <= 200) + + +@given( + standard_datasets( + dtypes=floating_dtypes(sizes=(32,)), + n_samples=st.integers(10, 20), + n_features=st.integers(30, 40), + ) +) +def test_standard_datasets(dataset): + X, y = dataset + + assert X.ndim == 2 + assert 10 <= X.shape[0] <= 20 + assert 30 <= X.shape[1] <= 40 + assert 10 <= y.shape[0] <= 20 + assert y.shape[1] == 1 + + +@given(split_datasets(standard_datasets())) +@settings(suppress_health_check=[HealthCheck.too_slow]) +def test_split_datasets(split_dataset): + X_train, X_test, y_train, y_test = split_dataset + + assert X_train.ndim == X_test.ndim == 2 + assert X_train.shape[1] == X_test.shape[1] + assert 2 <= (len(X_train) + len(X_test)) <= 200 + + assert y_train.ndim == y_test.ndim + assert y_train.ndim in (0, 1, 2) + assert (y_train.ndim == 0) or (2 <= (len(y_train) + len(y_test)) <= 200) + + +@given(standard_regression_datasets()) +def test_standard_regression_datasets_default(dataset): + X, y = dataset + assert X.ndim == 2 + assert X.shape[0] <= 200 + assert X.shape[1] <= 200 + assert (y.ndim == 0) or (y.ndim in (1, 2) and y.shape[0] <= 200) + assert X.dtype == y.dtype + + +@given( + standard_regression_datasets( + dtypes=floating_dtypes(sizes=64), + n_samples=st.integers(min_value=2, max_value=200), + n_features=st.integers(min_value=1, max_value=200), + n_informative=st.just(10), + random_state=0, + ) +) +def test_standard_regression_datasets(dataset): + + from sklearn.datasets import make_regression + + X, y = dataset + assert X.ndim == 2 + assert X.shape[0] <= 200 + assert X.shape[1] <= 200 + assert (y.ndim == 1 and y.shape[0] <= 200) or y.ndim == 0 + assert X.dtype == y.dtype + + X_cmp, y_cmp = make_regression( + n_samples=X.shape[0], n_features=X.shape[1], random_state=0 + ) + + assert X.dtype.type == X_cmp.dtype.type + assert X.ndim == X_cmp.ndim + assert X.shape == X_cmp.shape + assert y.dtype.type == y_cmp.dtype.type + assert y.ndim == y_cmp.ndim + assert y.shape == y_cmp.shape + assert (X == X_cmp).all() + assert (y == y_cmp).all() + + +@given(regression_datasets()) +def test_regression_datasets(dataset): + X, y = dataset + + assert X.ndim == 2 + assert X.shape[0] <= 200 + assert X.shape[1] <= 200 + assert (y.ndim == 0) or (y.ndim in (1, 2) and y.shape[0] <= 200) + + +@given(split_datasets(regression_datasets())) +@settings(suppress_health_check=[HealthCheck.too_slow]) +def test_split_regression_datasets(split_dataset): + X_train, X_test, y_train, y_test = split_dataset + + assert X_train.ndim == X_test.ndim == 2 + assert y_train.ndim == y_test.ndim + assert y_train.ndim in (0, 1, 2) + assert 2 <= (len(X_train) + len(X_test)) <= 200 diff --git a/wiki/python/DEVELOPER_GUIDE.md b/wiki/python/DEVELOPER_GUIDE.md index 3bb0eb35c8..55937d4b1a 100644 --- a/wiki/python/DEVELOPER_GUIDE.md +++ b/wiki/python/DEVELOPER_GUIDE.md @@ -49,6 +49,8 @@ Examples subject to numerical imprecision, or that can't be reproduced consisten ## Testing and Unit Testing We use [https://docs.pytest.org/en/latest/]() for writing and running tests. To see existing examples, refer to any of the `test_*.py` files in the folder `cuml/tests`. +Some tests are run against inputs generated with [hypothesis](https://hypothesis.works/). See the `cuml/testing/strategies.py` module for custom strategies that can be used to test cuml estimators with diverse inputs. For example, use the `regression_datasets()` strategy to test random regression problems. + ## Device and Host memory allocations TODO: talk about enabling RMM here when it is ready From 1592c16a8199456090c573f538053ba442e77504 Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Mon, 14 Nov 2022 09:50:50 -0500 Subject: [PATCH 2/2] Use rapdsi_cpm_find(COMPONENTS ) for proper component tracking (#4989) This allows rapids-cmake to generate the correct find_dependency(raft COMPONENTS ...) required to use cuml Authors: - Robert Maynard (https://github.com/robertmaynard) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/4989 --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 190d8d68b8..361684df79 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -55,11 +55,11 @@ function(find_and_configure_raft) GLOBAL_TARGETS raft::raft BUILD_EXPORT_SET cuml-exports INSTALL_EXPORT_SET cuml-exports + COMPONENTS ${RAFT_COMPONENTS} CPM_ARGS GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git GIT_TAG ${PKG_PINNED_TAG} SOURCE_SUBDIR cpp - FIND_PACKAGE_ARGUMENTS "COMPONENTS ${RAFT_COMPONENTS}" OPTIONS "BUILD_TESTS OFF" "RAFT_COMPILE_LIBRARIES ${RAFT_COMPILE_LIBRARIES}"