-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extract dask and spark test into distributed test. (#8395)
- Move test files. - Run spark and dask separately to prevent conflicts. - Gather common code into the testing module.
- Loading branch information
1 parent
f73520b
commit cfd2a9f
Showing 34 changed files with 406 additions and 338 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Strategies for updater tests.""" | ||
|
||
from typing import cast | ||
|
||
import pytest | ||
|
||
hypothesis = pytest.importorskip("hypothesis") | ||
from hypothesis import strategies # pylint:disable=wrong-import-position | ||
|
||
exact_parameter_strategy = strategies.fixed_dictionaries( | ||
{ | ||
"nthread": strategies.integers(1, 4), | ||
"max_depth": strategies.integers(1, 11), | ||
"min_child_weight": strategies.floats(0.5, 2.0), | ||
"alpha": strategies.floats(1e-5, 2.0), | ||
"lambda": strategies.floats(1e-5, 2.0), | ||
"eta": strategies.floats(0.01, 0.5), | ||
"gamma": strategies.floats(1e-5, 2.0), | ||
"seed": strategies.integers(0, 10), | ||
# We cannot enable subsampling as the training loss can increase | ||
# 'subsample': strategies.floats(0.5, 1.0), | ||
"colsample_bytree": strategies.floats(0.5, 1.0), | ||
"colsample_bylevel": strategies.floats(0.5, 1.0), | ||
} | ||
) | ||
|
||
hist_parameter_strategy = strategies.fixed_dictionaries( | ||
{ | ||
"max_depth": strategies.integers(1, 11), | ||
"max_leaves": strategies.integers(0, 1024), | ||
"max_bin": strategies.integers(2, 512), | ||
"grow_policy": strategies.sampled_from(["lossguide", "depthwise"]), | ||
"min_child_weight": strategies.floats(0.5, 2.0), | ||
# We cannot enable subsampling as the training loss can increase | ||
# 'subsample': strategies.floats(0.5, 1.0), | ||
"colsample_bytree": strategies.floats(0.5, 1.0), | ||
"colsample_bylevel": strategies.floats(0.5, 1.0), | ||
} | ||
).filter( | ||
lambda x: (cast(int, x["max_depth"]) > 0 or cast(int, x["max_leaves"]) > 0) | ||
and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide") | ||
) | ||
|
||
cat_parameter_strategy = strategies.fixed_dictionaries( | ||
{ | ||
"max_cat_to_onehot": strategies.integers(1, 128), | ||
"max_cat_threshold": strategies.integers(1, 128), | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
"""Testing code shared by other tests.""" | ||
# pylint: disable=invalid-name | ||
import collections | ||
import importlib.util | ||
import json | ||
import os | ||
import tempfile | ||
from typing import Any, Callable, Dict, Type | ||
|
||
import numpy as np | ||
from xgboost._typing import ArrayLike | ||
|
||
import xgboost as xgb | ||
|
||
|
||
def validate_leaf_output(leaf: np.ndarray, num_parallel_tree: int) -> None:
    """Validate output for predict leaf tests.

    ``leaf`` is indexed as (sample, round, class, tree-in-forest); each
    forest slice must contain ``num_parallel_tree`` identical leaf values.
    """
    n_samples, n_rounds, n_classes = leaf.shape[0], leaf.shape[1], leaf.shape[2]
    for sample_idx in range(n_samples):
        for round_idx in range(n_rounds):
            for class_idx in range(n_classes):
                forest = leaf[sample_idx, round_idx, class_idx, :]
                # One leaf prediction per parallel tree.
                assert forest.shape[0] == num_parallel_tree
                # No sampling, all trees within forest are the same, so every
                # entry must equal the first one.
                assert (forest == forest[0]).all()
|
||
|
||
def validate_data_initialization(
    dmatrix: Type, model: Type[xgb.XGBModel], X: ArrayLike, y: ArrayLike
) -> None:
    """Assert that we don't create duplicated DMatrix.

    Temporarily instruments ``dmatrix.__init__`` with a call counter, fits
    ``model`` twice, and checks how many DMatrix objects were constructed:
    one when the eval set shares the training data objects, two when the
    label is a distinct Python object.
    """
    old_init = dmatrix.__init__
    count = [0]  # mutable cell so the closure below can update it

    def new_init(self: Any, **kwargs: Any) -> Callable:
        count[0] += 1
        return old_init(self, **kwargs)

    dmatrix.__init__ = new_init
    try:
        model(n_estimators=1).fit(X, y, eval_set=[(X, y)])
        assert count[0] == 1  # only 1 DMatrix is created.
        count[0] = 0

        y_copy = y.copy()
        model(n_estimators=1).fit(X, y, eval_set=[(X, y_copy)])
        assert count[0] == 2  # a different Python object is considered different
    finally:
        # Restore the constructor even when an assertion fails or fit raises,
        # so the instrumented __init__ never leaks into other tests.
        dmatrix.__init__ = old_init
|
||
|
||
# pylint: disable=too-many-arguments,too-many-locals
def get_feature_weights(
    X: ArrayLike,
    y: ArrayLike,
    fw: np.ndarray,
    parser_path: str,
    tree_method: str,
    model: Type[xgb.XGBModel] = xgb.XGBRegressor,
) -> np.ndarray:
    """Get feature weights using the demo parser.

    Trains ``model`` with per-feature sampling weights ``fw`` and node-level
    column subsampling, parses the saved JSON model with the parser module
    found at ``parser_path``, counts how often each feature is chosen for a
    split, and returns the coefficients of a degree-1 polynomial fitted to
    (feature index, split count).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        colsample_bynode = 0.5
        reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)

        reg.fit(X, y, feature_weights=fw)
        model_path = os.path.join(tmpdir, "model.json")
        reg.save_model(model_path)
        # Use a distinct name for the JSON document instead of clobbering the
        # `model` parameter (the original shadowed it twice).
        with open(model_path, "r", encoding="utf-8") as fd:
            model_doc = json.load(fd)

        # Load the demo JSON-model parser as a throw-away module.
        spec = importlib.util.spec_from_file_location("JsonParser", parser_path)
        assert spec is not None
        jsonm = importlib.util.module_from_spec(spec)
        assert spec.loader is not None
        spec.loader.exec_module(jsonm)
        parsed = jsonm.Model(model_doc)

        # Count how many split nodes use each feature index.
        splits: Dict[int, int] = {}
        for tree in parsed.trees:
            for n in range(len(tree.nodes)):
                if tree.is_leaf(n):
                    continue
                feature = tree.split_index(n)
                splits[feature] = splits.get(feature, 0) + 1

        # Fit a line over (feature index, split count); the slope reflects how
        # strongly the feature weights biased split selection.
        od = collections.OrderedDict(sorted(splits.items()))
        k, v = zip(*od.items())
        return np.polyfit(k, v, deg=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,5 @@ dependencies: | |
  - isort
  - pyspark
  - cloudpickle
  - pytest
  - hypothesis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.