Commit
Merge pull request #4997 from rapidsai/branch-22.12
[gpuCI] Forward-merge branch-22.12 to branch-23.02 [skip gpuci]
GPUtester authored Nov 15, 2022
2 parents e3dcfd3 + 6500897 commit 1b61db0
Showing 6 changed files with 508 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
@@ -55,11 +55,11 @@ function(find_and_configure_raft)
GLOBAL_TARGETS raft::raft
BUILD_EXPORT_SET cuml-exports
INSTALL_EXPORT_SET cuml-exports
- COMPONENTS ${RAFT_COMPONENTS}
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git
GIT_TAG ${PKG_PINNED_TAG}
SOURCE_SUBDIR cpp
+ FIND_PACKAGE_ARGUMENTS "COMPONENTS ${RAFT_COMPONENTS}"
OPTIONS
"BUILD_TESTS OFF"
"RAFT_COMPILE_LIBRARIES ${RAFT_COMPILE_LIBRARIES}"
250 changes: 250 additions & 0 deletions python/cuml/testing/strategies.py
@@ -0,0 +1,250 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from hypothesis import assume
from hypothesis.extra.numpy import arrays, floating_dtypes
from hypothesis.strategies import composite, integers, just, none, one_of
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


@composite
def standard_datasets(
    draw,
    dtypes=floating_dtypes(),
    n_samples=integers(min_value=0, max_value=200),
    n_features=integers(min_value=0, max_value=200),
    *,
    n_targets=just(1),
):
    """
    Returns a strategy to generate standard estimator input datasets.

    Parameters
    ----------
    dtypes: SearchStrategy[np.dtype], default=floating_dtypes()
        Returned arrays will have a dtype drawn from these types.
    n_samples: SearchStrategy[int], \
        default=integers(min_value=0, max_value=200)
        Returned arrays will have number of rows drawn from these values.
    n_features: SearchStrategy[int], \
        default=integers(min_value=0, max_value=200)
        Returned arrays will have number of columns drawn from these values.
    n_targets: SearchStrategy[int], default=just(1)
        Determines the number of targets returned datasets may contain.

    Returns
    -------
    X: SearchStrategy[array] (n_samples, n_features)
        The search strategy for input samples.
    y: SearchStrategy[array] (n_samples,) or (n_samples, n_targets)
        The search strategy for output samples.
    """
    xs = draw(n_samples)
    ys = draw(n_features)
    X = arrays(dtype=dtypes, shape=(xs, ys))
    y = arrays(dtype=dtypes, shape=(xs, draw(n_targets)))
    return draw(X), draw(y)


def combined_datasets_strategy(*datasets, name=None, doc=None):
    """
    Combine multiple datasets strategies into a single datasets strategy.

    This function will return a new strategy that will build the provided
    strategy functions with the common parameters (dtypes, n_samples,
    n_features) and then draw from one of them.

    Parameters
    ----------
    *datasets: list[Callable[[dtypes, n_samples, n_features], SearchStrategy]]
        A list of functions that return a dataset search strategy when called
        with the shown arguments.
    name: str, default=None
        The name of the returned search strategy. Defaults to "datasets".
    doc: str, default=None
        The doc-string of the returned search strategy. Defaults to a generic
        doc-string.

    Returns
    -------
    Datasets search strategy: SearchStrategy[array], SearchStrategy[array]
    """

    @composite
    def strategy(
        draw,
        dtypes=floating_dtypes(),
        n_samples=integers(min_value=0, max_value=200),
        n_features=integers(min_value=0, max_value=200),
    ):
        """Datasets strategy composed of multiple datasets strategies."""
        datasets_strategies = (
            dataset(dtypes, n_samples, n_features) for dataset in datasets)
        return draw(one_of(datasets_strategies))

    strategy.__name__ = "datasets" if name is None else name
    if doc is not None:
        strategy.__doc__ = doc

    return strategy


@composite
def split_datasets(
    draw,
    datasets,
    test_sizes=None,
):
    """
    Split a generic search strategy for datasets into test and train subsets.

    The resulting split is guaranteed to have at least one sample in both the
    train and the test split.

    Note: This function uses the sklearn.model_selection.train_test_split
    function.

    See also:
    standard_datasets(): A search strategy for datasets that can serve as
    input to this strategy.

    Parameters
    ----------
    datasets: SearchStrategy[dataset]
        A search strategy for datasets.
    test_sizes: SearchStrategy[int | float], default=None
        A search strategy for the test size. Must be provided as a search
        strategy for integers or floats. Integers should be bound by one and
        the sample size, floats should be between 0 and 1.0. Defaults to
        a search strategy that will generate a valid unbiased split.

    Returns
    -------
    (X_train, X_test, y_train, y_test): tuple[SearchStrategy[array], ...]
        The train-test split of the input and output samples drawn from
        the provided datasets search strategy.
    """
    X, y = draw(datasets)
    assume(len(X) > 1)

    # Determine default value for test_sizes
    if test_sizes is None:
        test_sizes = integers(1, max(1, len(X) - 1))

    test_size = draw(test_sizes)

    # Check assumptions for test_size
    if isinstance(test_size, float):
        assume(int(len(X) * test_size) > 0)
        assume(int(len(X) * (1.0 - test_size)) > 0)
    elif isinstance(test_size, int):
        assume(1 < test_size < len(X))

    return train_test_split(X, y, test_size=test_size)


@composite
def standard_regression_datasets(
    draw,
    dtypes=floating_dtypes(),
    n_samples=integers(min_value=100, max_value=200),
    n_features=integers(min_value=100, max_value=200),
    *,
    n_informative=None,
    n_targets=just(1),
    bias=just(0.0),
    effective_rank=none(),
    tail_strength=just(0.5),
    noise=just(0.0),
    shuffle=just(True),
    random_state=None,
):
    """
    Returns a strategy to generate regression problem input datasets.

    Note:
    This function uses the sklearn.datasets.make_regression function to
    generate the regression problem from the provided search strategies.

    Parameters
    ----------
    dtypes: SearchStrategy[np.dtype]
        Returned arrays will have a dtype drawn from these types.
    n_samples: SearchStrategy[int]
        Returned arrays will have number of rows drawn from these values.
    n_features: SearchStrategy[int]
        Returned arrays will have number of columns drawn from these values.
    n_informative: SearchStrategy[int], default=None
        A search strategy for the number of informative features. If None,
        will use 10% of the actual number of features, but not less than 1
        unless the number of features is zero.
    n_targets: SearchStrategy[int], default=just(1)
        A search strategy for the number of targets, that is, the number of
        columns of the returned y output array.
    bias: SearchStrategy[float], default=just(0.0)
        A search strategy for the bias term.
    effective_rank: SearchStrategy[int], default=none()
        If not None, a search strategy for the effective rank of the input
        data for the regression problem. See
        sklearn.datasets.make_regression() for a detailed explanation of
        this parameter.
    tail_strength: SearchStrategy[float], default=just(0.5)
        See sklearn.datasets.make_regression() for a detailed explanation of
        this parameter.
    noise: SearchStrategy[float], default=just(0.0)
        A search strategy for the standard deviation of the Gaussian noise.
    shuffle: SearchStrategy[bool], default=just(True)
        A boolean search strategy to determine whether samples and features
        are shuffled.
    random_state: int, RandomState instance or None, default=None
        Pass a random state or integer to determine the random number
        generation for data set generation.

    Returns
    -------
    (X, y): SearchStrategy[array], SearchStrategy[array]
        A tuple of search strategies for arrays subject to the constraints of
        the provided parameters.
    """
    n_features_ = draw(n_features)
    if n_informative is None:
        n_informative = just(max(min(n_features_, 1), int(0.1 * n_features_)))
    X, y = make_regression(
        n_samples=draw(n_samples),
        n_features=n_features_,
        n_informative=draw(n_informative),
        n_targets=draw(n_targets),
        bias=draw(bias),
        effective_rank=draw(effective_rank),
        tail_strength=draw(tail_strength),
        noise=draw(noise),
        shuffle=draw(shuffle),
        random_state=random_state,
    )
    dtype_ = draw(dtypes)
    return X.astype(dtype_), y.astype(dtype_)


regression_datasets = combined_datasets_strategy(
    standard_datasets, standard_regression_datasets,
    name="regression_datasets",
    doc="""
    Returns strategy for the generation of regression problem datasets.

    Drawn from the standard_datasets and the standard_regression_datasets
    strategies.
    """,
)
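
These strategies are meant to be drawn from with hypothesis's @given
decorator. As a rough sketch of the intended usage (the test name and
assertions below are illustrative, not part of this commit), a
property-based test built on the new strategies might look like this:

    # Illustrative usage sketch, not part of the diff above.
    from hypothesis import given
    from cuml.testing.strategies import regression_datasets, split_datasets


    @given(split_datasets(regression_datasets()))
    def test_split_properties(split):
        X_train, X_test, y_train, y_test = split
        # split_datasets promises at least one sample on each side.
        assert len(X_train) >= 1 and len(X_test) >= 1
        # Rows of X and y stay aligned across the split.
        assert len(X_train) == len(y_train)
        assert len(X_test) == len(y_test)

Because regression_datasets draws from both standard_datasets and
standard_regression_datasets, a test like this exercises plain array
inputs as well as make_regression outputs with a single strategy.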
36 changes: 36 additions & 0 deletions python/cuml/tests/conftest.py
@@ -20,6 +20,7 @@

import numpy as np
import cupy as cp
import hypothesis

from math import ceil
from sklearn.datasets import fetch_20newsgroups
@@ -34,6 +35,30 @@
pytest_plugins = ("cuml.testing.plugins.quick_run_plugin")


# Configure hypothesis profiles

hypothesis.settings.register_profile(
    name="unit",
    parent=hypothesis.settings.get_profile("default"),
    max_examples=20,
    suppress_health_check=[
        hypothesis.HealthCheck.data_too_large,
    ],
)

hypothesis.settings.register_profile(
    name="quality",
    parent=hypothesis.settings.get_profile("unit"),
    max_examples=100,
)

hypothesis.settings.register_profile(
    name="stress",
    parent=hypothesis.settings.get_profile("quality"),
    max_examples=200,
)


def pytest_addoption(parser):
    # Any custom option that should be available at any time (not just in a
    # plugin) goes here.
@@ -101,6 +126,17 @@ def pytest_configure(config):
    pytest.max_gpu_memory = get_gpu_memory()
    pytest.adapt_stress_test = 'CUML_ADAPT_STRESS_TESTS' in os.environ

    # Load special hypothesis profiles for either quality or stress tests.
    # Note that the profile can be manually overwritten with the
    # --hypothesis-profile command line option in which case the settings
    # specified here will be ignored.
    if config.getoption("--run_stress"):
        hypothesis.settings.load_profile("stress")
    elif config.getoption("--run_quality"):
        hypothesis.settings.load_profile("quality")
    else:
        hypothesis.settings.load_profile("unit")


@pytest.fixture(scope="module")
def nlp_20news():
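
The three profiles form an inheritance chain (unit -> quality -> stress),
so a setting registered on a parent, such as the suppressed
data_too_large health check, carries over to the children unless
overridden. A self-contained sketch of that pattern (the profile names
"base" and "extended" here are hypothetical, independent of cuml's
conftest):

    # Minimal sketch of hypothesis profile inheritance; names are made up.
    import hypothesis

    hypothesis.settings.register_profile(
        name="base",
        max_examples=20,
        suppress_health_check=[hypothesis.HealthCheck.data_too_large],
    )
    hypothesis.settings.register_profile(
        name="extended",
        parent=hypothesis.settings.get_profile("base"),
        max_examples=100,
    )

    extended = hypothesis.settings.get_profile("extended")
    assert extended.max_examples == 100
    # Attributes that are not overridden are inherited from the parent:
    assert hypothesis.HealthCheck.data_too_large in extended.suppress_health_check

As the comment in pytest_configure notes, passing --hypothesis-profile on
the pytest command line overrides whichever profile is loaded here.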
(Diff of the remaining changed files not shown.)
