Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed Nov 15, 2024
1 parent 1843724 commit e6567ac
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 57 deletions.
73 changes: 21 additions & 52 deletions tests/common/test_trees_trees_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# author: Jungtaek Kim ([email protected])
# last updated: August 17, 2021
# author: Jungtaek Kim ([email protected])
# last updated: November 15, 2024
#
"""test_trees_trees_common"""

Expand All @@ -12,9 +12,6 @@
from bayeso.trees import trees_common as package_target


TEST_EPSILON = 1e-7


def test_get_inputs_from_leaf_typing():
annos = package_target.get_inputs_from_leaf.__annotations__

Expand Down Expand Up @@ -136,7 +133,7 @@ def test_mse():
assert output == 1e8 + 0.015

output = package_target.mse((left, right))
assert np.abs(output - (0.021875 + 0.015)) < TEST_EPSILON
np.testing.assert_allclose(output, 0.021875 + 0.015)


def test_subsample_typing():
Expand Down Expand Up @@ -203,8 +200,9 @@ def test_subsample():
]
)

assert np.all(np.abs(X_truth - X_) < TEST_EPSILON)
assert np.all(np.abs(Y_truth - Y_) < TEST_EPSILON)
np.testing.assert_allclose(X_, X_truth)
np.testing.assert_allclose(Y_, Y_truth)

assert X_.shape[0] == Y_.shape[0] == X_truth.shape[0] == Y_truth.shape[0]
assert X_.shape[0] == Y_.shape[0] == int(ratio_sampling * X.shape[0])

Expand Down Expand Up @@ -245,8 +243,9 @@ def test_subsample():
]
)

assert np.all(np.abs(X_truth - X_) < TEST_EPSILON)
assert np.all(np.abs(Y_truth - Y_) < TEST_EPSILON)
np.testing.assert_allclose(X_, X_truth)
np.testing.assert_allclose(Y_, Y_truth)

assert X_.shape[0] == Y_.shape[0] == X_truth.shape[0] == Y_truth.shape[0]
assert X_.shape[0] == Y_.shape[0] == int(1.2 * X.shape[0])

Expand Down Expand Up @@ -346,64 +345,34 @@ def test__split():
assert dict_split["value"] == 35.0

assert np.all(dict_split["left_right"][0][0][0] == np.array([0, 1, 2, 3]))
assert (
np.abs(dict_split["left_right"][0][0][1] - np.array([0.49671415]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][0][1], np.array([0.49671415]))

assert np.all(dict_split["left_right"][0][1][0] == np.array([4, 5, 6, 7]))
assert (
np.abs(dict_split["left_right"][0][1][1] - np.array([-0.1382643]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][1][1], np.array([-0.1382643]))

assert np.all(dict_split["left_right"][0][2][0] == np.array([8, 9, 10, 11]))
assert (
np.abs(dict_split["left_right"][0][2][1] - np.array([0.64768854]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][2][1], np.array([0.64768854]))

assert np.all(dict_split["left_right"][0][3][0] == np.array([12, 13, 14, 15]))
assert (
np.abs(dict_split["left_right"][0][3][1] - np.array([1.52302986]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][3][1], np.array([1.52302986]))

assert np.all(dict_split["left_right"][0][4][0] == np.array([16, 17, 18, 19]))
assert (
np.abs(dict_split["left_right"][0][4][1] - np.array([-0.23415337]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][4][1], np.array([-0.23415337]))

assert np.all(dict_split["left_right"][0][5][0] == np.array([20, 21, 22, 23]))
assert (
np.abs(dict_split["left_right"][0][5][1] - np.array([-0.23413696]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][5][1], np.array([-0.23413696]))

assert np.all(dict_split["left_right"][0][6][0] == np.array([24, 25, 26, 27]))
assert (
np.abs(dict_split["left_right"][0][6][1] - np.array([1.57921282]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][6][1], np.array([1.57921282]))

assert np.all(dict_split["left_right"][0][7][0] == np.array([28, 29, 30, 31]))
assert (
np.abs(dict_split["left_right"][0][7][1] - np.array([0.76743473]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][7][1], np.array([0.76743473]))

assert np.all(dict_split["left_right"][0][8][0] == np.array([32, 33, 34, 35]))
assert (
np.abs(dict_split["left_right"][0][8][1] - np.array([-0.46947439]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][0][8][1], np.array([-0.46947439]))

assert np.all(dict_split["left_right"][1][0][0] == np.array([36, 37, 38, 39]))
assert (
np.abs(dict_split["left_right"][1][0][1] - np.array([0.54256004]))
< TEST_EPSILON
)
np.testing.assert_allclose(dict_split["left_right"][1][0][1], np.array([0.54256004]))

dict_split = package_target._split(X, Y, num_features, True)
print(dict_split)
Expand Down Expand Up @@ -666,8 +635,8 @@ def test_predict_by_trees():
assert means.shape[0] == stds.shape[0] == X.shape[0]
assert means.shape[1] == stds.shape[1] == 1

assert np.all(np.abs(means - means_truth) < TEST_EPSILON)
assert np.all(np.abs(stds - stds_truth) < TEST_EPSILON)
np.testing.assert_allclose(means, means_truth)
np.testing.assert_allclose(stds, stds_truth)

X = np.random.randn(1000, 4)

Expand Down
7 changes: 2 additions & 5 deletions tests/common/test_trees_trees_generic_trees.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# author: Jungtaek Kim ([email protected])
# last updated: August 20, 2021
# author: Jungtaek Kim ([email protected])
# last updated: November 15, 2024
#
"""test_trees_trees_generic_trees"""

Expand All @@ -12,9 +12,6 @@
from bayeso.trees import trees_generic_trees as package_target


TEST_EPSILON = 1e-7


def test_get_generic_trees_typing():
annos = package_target.get_generic_trees.__annotations__

Expand Down

0 comments on commit e6567ac

Please sign in to comment.