Commit

Hotfix/random proposals until valid evals (#118)
* fix sampling of random specs until enough valid evals reported
* simplify num_initial_random check

---------

Signed-off-by: Grossberger Lukas (CR/AIR2.2) <[email protected]>
LGro authored Dec 19, 2023
1 parent 36cd141 commit 2dce548
Showing 2 changed files with 58 additions and 4 deletions.
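
As a reading aid for the diff below, here is a minimal standalone sketch of the old and new checks deciding whether the next proposal is a random sample. The names `X`, `losses`, and `num_initial_random` stand in for the optimizer attributes `self.X`, `self.losses`, and `self.num_initial_random`; this is an illustration, not code from the repository.

import torch

# Three evaluations were reported, but only one has a valid objective value.
X = torch.zeros((3, 1))
losses = torch.tensor([[0.3], [float("nan")], [float("nan")]])
num_initial_random = 2

# Old check: random sampling stops once `num_initial_random` evaluations were
# reported and at least one of them has a non-NaN objective value.
old_propose_random = num_initial_random > 0 and (
    X.size(-2) < num_initial_random
    or torch.nonzero(~torch.any(losses.isnan(), dim=1)).numel() == 0
)

# New check: random sampling continues until `num_initial_random` evaluations
# with non-NaN objective values were reported.
# (The diff uses the bare sum; int() here only keeps the printed output plain.)
new_propose_random = num_initial_random > 0 and (
    int(sum(~torch.any(losses.isnan(), dim=1))) < num_initial_random
)

print(old_propose_random, new_propose_random)  # False True

With the old check, model-based proposals would already start in this situation although only one valid evaluation is available; the new check keeps proposing random configurations until two valid evaluations have been reported.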
12 changes: 8 additions & 4 deletions blackboxopt/optimizers/botorch_base.py
@@ -173,6 +173,9 @@ def __init__(
 num_initial_random_samples: Size of the initial space-filling design that
     is used before starting BO. The points are sampled randomly in the
     search space. If no random sampling is required, set it to 0.
+    When random sampling is enabled but evaluations with missing objective
+    values are reported, more specifications are sampled until
+    `num_initial_random_samples` valid evaluations have been reported.
 max_pending_evaluations: Maximum number of parallel evaluations. For
     sequential BO use the default value of 1. If no limit is required,
     set it to None.
@@ -275,17 +278,18 @@ def generate_evaluation_specification(self) -> EvaluationSpecification:
 ):
     raise OptimizerNotReady

+# Generate random samples until there are enough samples where at least one of
+# the objective values is available
 if self.num_initial_random > 0 and (
-    self.X.size(-2) < self.num_initial_random
-    or torch.nonzero(~torch.any(self.losses.isnan(), dim=1)).numel() == 0
+    sum(~torch.any(self.losses.isnan(), dim=1)) < self.num_initial_random
 ):
-    # We keep generating random samples until there are enough samples, and
-    # at least one of them has a valid objective
     eval_spec = EvaluationSpecification(
         configuration=self.search_space.sample(),
         optimizer_info={"model_based_pick": False},
     )
 else:
     eval_spec = self._generate_evaluation_specification()
     eval_spec.optimizer_info["model_based_pick"] = True

 eval_id = self.X.size(-2) + len(self.pending_specifications)
 eval_spec.optimizer_info["evaluation_id"] = eval_id
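
One detail worth noting about the new condition above: `~torch.any(self.losses.isnan(), dim=1)` marks exactly those rows of the loss tensor in which no objective value is missing, so only evaluations whose objective values are all present count towards `num_initial_random_samples`. A minimal sketch with a hypothetical two-objective loss tensor (illustration only, not code from the repository):

import torch

# One row per reported evaluation, one column per objective; NaN marks a
# missing objective value.
losses = torch.tensor(
    [
        [0.2, 1.0],                    # all objectives present -> counted
        [0.5, float("nan")],           # partially missing      -> not counted
        [float("nan"), float("nan")],  # fully missing          -> not counted
    ]
)

valid_rows = ~torch.any(losses.isnan(), dim=1)
print(valid_rows.tolist())    # [True, False, False]
print(int(valid_rows.sum()))  # 1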
50 changes: 50 additions & 0 deletions tests/optimizers/botorch_base_test.py
@@ -170,6 +170,56 @@ def test_find_optimum_in_1d_discrete_space(seed):
     assert opt.objective.name in best.objectives


+def test_propose_random_until_enough_evaluations_without_missing_objective_values(seed):
+    space = ps.ParameterSpace()
+    space.add(ps.IntegerParameter("integ", (0, 2)))
+    batch_shape = torch.Size()
+
+    opt = SingleObjectiveBOTorchOptimizer(
+        search_space=space,
+        objective=Objective("loss", greater_is_better=False),
+        model=SingleTaskGP(
+            torch.empty((*batch_shape, 0, len(space)), dtype=torch.float64),
+            torch.empty((*batch_shape, 0, 1), dtype=torch.float64),
+        ),
+        acquisition_function_factory=partial(
+            UpperConfidenceBound, beta=1.0, maximize=False
+        ),
+        num_initial_random_samples=2,
+        max_pending_evaluations=1,
+        seed=seed,
+    )
+
+    es = opt.generate_evaluation_specification()
+    assert not es.optimizer_info[
+        "model_based_pick"
+    ], "No evaluation reported, 0 < 2 initial random samples"
+    opt.report(
+        es.create_evaluation(objectives={"loss": es.configuration["integ"] ** 2}),
+    )
+
+    es = opt.generate_evaluation_specification()
+    assert not es.optimizer_info[
+        "model_based_pick"
+    ], "One evaluation reported, 1 < 2 initial random samples"
+    opt.report(
+        es.create_evaluation(objectives={"loss": None}),
+    )
+
+    es = opt.generate_evaluation_specification()
+    assert not es.optimizer_info[
+        "model_based_pick"
+    ], "One valid evaluation reported, 1 < 2 initial random samples"
+    opt.report(
+        es.create_evaluation(objectives={"loss": es.configuration["integ"] ** 2}),
+    )
+
+    es = opt.generate_evaluation_specification()
+    assert es.optimizer_info[
+        "model_based_pick"
+    ], "Two valid evaluations reported, 2 >= 2 initial random samples"
+
+
 def test_get_numerical_points_from_discrete_space():
     p0l, p0h = -5, 10
     p1 = ("small", "medium", "large")
