Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed Jun 21, 2024
1 parent cd3c39f commit a554066
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions tests/common/test_bo_bo_w_trees.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# author: Jungtaek Kim ([email protected])
# last updated: August 16, 2023
# author: Jungtaek Kim ([email protected])
# last updated: June 21, 2024
#
"""test_bo_bo_w_trees"""

Expand All @@ -14,7 +14,6 @@


BO = package_target.BOwTrees
TEST_EPSILON = 1e-5


def test_load_bo():
Expand Down Expand Up @@ -229,8 +228,7 @@ def test_get_samples():
],
]
)

assert (np.abs(arr_initials - truth_arr_initials) < TEST_EPSILON).all()
np.testing.assert_allclose(arr_initials, truth_arr_initials)

arr_initials_ = model_bo.get_samples("uniform", num_samples=3)
arr_initials = model_bo.get_samples("uniform", num_samples=3, seed=42)
Expand All @@ -241,7 +239,7 @@ def test_get_samples():
[0.58083612, 1.46470458, 1.01115012],
]
)
assert (np.abs(arr_initials - truth_arr_initials) < TEST_EPSILON).all()
np.testing.assert_allclose(arr_initials, truth_arr_initials)

arr_initials_ = model_bo.get_samples("gaussian", num_samples=3)
arr_initials = model_bo.get_samples("gaussian", num_samples=3, seed=42)
Expand All @@ -252,7 +250,7 @@ def test_get_samples():
[8.948032038768478, 0.7674347291529088, -1.1736859648373803],
]
)
assert (np.abs(arr_initials - truth_arr_initials) < TEST_EPSILON).all()
np.testing.assert_allclose(arr_initials, truth_arr_initials)


def test_get_initials():
Expand Down Expand Up @@ -312,8 +310,7 @@ def test_get_initials():
],
]
)

assert (np.abs(arr_initials - truth_arr_initials) < TEST_EPSILON).all()
np.testing.assert_allclose(arr_initials, truth_arr_initials)

arr_initials = model_bo.get_initials("uniform", 3, seed=42)
truth_arr_initials = np.array(
Expand All @@ -323,7 +320,7 @@ def test_get_initials():
[0.58083612, 1.46470458, 1.01115012],
]
)
assert (np.abs(arr_initials - truth_arr_initials) < TEST_EPSILON).all()
np.testing.assert_allclose(arr_initials, truth_arr_initials)


def test_optimize():
Expand Down

0 comments on commit a554066

Please sign in to comment.