Changes to Adapter constructors, require experiment (#3415)
Summary:

The motivation for this change is to bring us closer to relying on the `experiment` as the source of truth about experiment state & attributes within the modeling layer. We currently accept many inputs that are extracted from the `experiment` only to be passed in alongside it. By making `experiment` a required input, we open up the possibility of removing these extra inputs and extracting them directly from `experiment` where they are needed.

Makes the following changes to Adapter constructors (see the sketch after this list):
- Requires keyword-only arguments. Positional inputs are no longer supported.
- Makes `experiment` required and `search_space` optional.
- Re-orders inputs for consistency across sub-classes.
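
A minimal sketch of the new calling convention, assuming a hypothetical toy experiment (the parameter and experiment names below are illustrative, not from this diff):

```python
from ax.core.experiment import Experiment
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.registry import Generators

# Hypothetical one-parameter search space, for illustration only.
search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
)
experiment = Experiment(name="demo", search_space=search_space)

# Before: Generators.SOBOL(search_space, ...) -- positional, `experiment` optional.
# After: inputs are keyword-only, `experiment` is required, and `search_space`
# falls back to `experiment.search_space` when omitted.
sobol = Generators.SOBOL(experiment=experiment)
generator_run = sobol.gen(n=1)
```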

In addition:
- Removes `model` input to `Adapter._fit`. This is a private method that is only called through `fit_if_implemented` (with `self.model`). Accepting multiple inputs for the same argument only makes the code harder to reason about.
- Removes class level attributes, some of which weren't initialized in `__init__`, leading to pyre complaints. All attributes are now initialized in `__init__`. This also eliminates misleading "optional" type hints with `None` default for `model`, which is never `None` in practice.
- Removes `Adapter.update`, which has been deprecated for quite some time.
- Initializing a Generator from the registry with only a `search_space` is being deprecated. It is temporarily supported via a dummy experiment for random & discrete adapters, which previously did not require an `experiment` (see the sketch below).
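
A sketch of that temporary shim, mirroring the `ax/modelbridge/factory.py` change in the diff below (`search_space` is the toy one from the sketch above):

```python
from ax.core.experiment import Experiment
from ax.modelbridge.registry import Generators

# Deprecated: Generators.SOBOL(search_space=search_space)
# Temporary shim: wrap the search space in a dummy experiment so that the
# now-required `experiment` input exists.
sobol = Generators.SOBOL(experiment=Experiment(search_space=search_space))
```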

Differential Revision: D70103442
saitcakmak authored and facebook-github-bot committed Feb 25, 2025
1 parent a03b8c6 commit c0c9dba
Showing 25 changed files with 519 additions and 662 deletions.
5 changes: 4 additions & 1 deletion ax/benchmark/problems/surrogate/lcbench/transfer_learning.py
@@ -23,6 +23,7 @@
from ax.modelbridge.registry import Cont_X_trans, Generators, Y_trans
from ax.modelbridge.torch import TorchAdapter
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
+from ax.models.torch.botorch_modular.model import BoTorchGenerator
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.utils.testing.mock import skip_fit_gpytorch_mll_context_manager
from botorch.models import SingleTaskGP
@@ -133,7 +134,9 @@ def get_surrogate() -> TorchAdapter:
data=obj["data"],
transforms=Cont_X_trans + Y_trans,
)
-        mb.model.surrogate.model.load_state_dict(obj["state_dict"])
+        assert_is_instance(mb.model, BoTorchGenerator).surrogate.model.load_state_dict(
+            obj["state_dict"]
+        )
return assert_is_instance(mb, TorchAdapter)

name = f"LCBench_Surrogate_{dataset_name}:v1"
7 changes: 5 additions & 2 deletions ax/generation_strategy/tests/test_dispatch_utils.py
@@ -22,6 +22,7 @@
from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans
from ax.modelbridge.transforms.log_y import LogY
from ax.modelbridge.transforms.winsorize import Winsorize
+from ax.models.random.sobol import SobolGenerator
from ax.models.winsorization_config import WinsorizationConfig
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
@@ -35,7 +36,7 @@
run_branin_experiment_with_generation_strategy,
)
from ax.utils.testing.mock import mock_botorch_optimize
-from pyre_extensions import none_throws
+from pyre_extensions import assert_is_instance, none_throws


class TestDispatchUtils(TestCase):
@@ -406,7 +407,9 @@ def test_setting_random_seed(self) -> None:
)
sobol.gen(experiment=get_experiment(), n=1)
# First model is actually a bridge, second is the Sobol engine.
-        self.assertEqual(none_throws(sobol.model).model.seed, 9)
+        self.assertEqual(
+            assert_is_instance(none_throws(sobol.model).model, SobolGenerator).seed, 9
+        )

with self.subTest("warns if use_saasbo is true"):
with self.assertLogs(
18 changes: 12 additions & 6 deletions ax/generation_strategy/tests/test_generation_strategy.py
@@ -580,10 +580,13 @@ def test_sobol_MBM_strategy(self) -> None:
)
ms = none_throws(g._model_state_after_gen).copy()
# Compare the model state to Sobol state.
-                sobol_model = none_throws(gs.model).model
+                sobol_model = assert_is_instance(
+                    none_throws(gs.model).model, SobolGenerator
+                )
self.assertTrue(
np.array_equal(
ms.pop("generated_points"), sobol_model.generated_points
ms.pop("generated_points"),
none_throws(sobol_model.generated_points),
)
)
# Replace expected seed with the one generated in __init__.
@@ -714,9 +717,9 @@ def test_with_factory_function(self) -> None:
"""Checks that generation strategy works with custom factory functions.
No information about the model should be saved on generator run."""

-        def get_sobol(search_space: SearchSpace) -> RandomAdapter:
+        def get_sobol(experiment: Experiment) -> RandomAdapter:
             return RandomAdapter(
-                search_space=search_space,
+                experiment=experiment,
model=SobolGenerator(),
transforms=Cont_X_trans,
)
@@ -1551,10 +1554,13 @@ def test_gs_with_generation_nodes(self) -> None:
)
ms = none_throws(g._model_state_after_gen).copy()
# Compare the model state to Sobol state.
-            sobol_model = none_throws(self.sobol_MBM_GS_nodes.model).model
+            sobol_model = assert_is_instance(
+                none_throws(self.sobol_MBM_GS_nodes.model).model, SobolGenerator
+            )
self.assertTrue(
np.array_equal(
ms.pop("generated_points"), sobol_model.generated_points
ms.pop("generated_points"),
none_throws(sobol_model.generated_points),
)
)
# Replace expected seed with the one generated in __init__.
81 changes: 35 additions & 46 deletions ax/modelbridge/base.py
@@ -45,6 +45,7 @@
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cast import Cast
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
+from ax.models.base import Generator
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from botorch.settings import validate_input_scaling
@@ -90,18 +91,18 @@ class Adapter(ABC): # noqa: B024 -- Adapter doesn't have any abstract methods.
receives appropriate inputs.
Subclasses will implement what is here referred to as the "terminal
transform," which is a transform that changes types of the data and problem
transform", which is a transform that changes types of the data and problem
specification.
"""

def __init__(
self,
-        search_space: SearchSpace,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        model: Any,
-        transforms: list[type[Transform]] | None = None,
-        experiment: Experiment | None = None,
+        *,
+        experiment: Experiment,
+        model: Generator,
+        search_space: SearchSpace | None = None,
         data: Data | None = None,
+        transforms: list[type[Transform]] | None = None,
transform_configs: dict[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
@@ -117,14 +118,22 @@ def __init__(
Applies transforms and fits model.
Args:
-            experiment: Is used to get arm parameters. Is not mutated.
-            search_space: Search space for fitting the model. Constraints need
-                not be the same ones used in gen. RangeParameter bounds are
-                considered soft and will be expanded to match the range of the
-                data sent in for fitting, if expand_model_space is True.
-            data: Ax Data.
-            model: Interface will be specified in subclass. If model requires
+            experiment: An ``Experiment`` object representing the setup and the
+                current state of the experiment, including the search space,
+                trials and observation data. It is used to extract various
+                attributes, and is not mutated.
+            model: A ``Generator`` that is used for generating candidates.
+                Its interface will be specified in subclasses. If model requires
                 initialization, that should be done prior to its use here.
+            search_space: An optional ``SearchSpace`` for fitting the model.
+                If not provided, `experiment.search_space` is used.
+                The search space may be modified during ``Adapter.gen``, e.g.,
+                to try out a different set of parameter bounds or constraints.
+                The bounds of the ``RangeParameter``s are considered soft and
+                will be expanded to match the range of the data sent in for fitting,
+                if `expand_model_space` is True.
+            data: An optional ``Data`` object, containing mean and SEM observations.
+                If `None`, extracted using `experiment.lookup_data()`.
transforms: List of uninitialized transform classes. Forward
transforms will be applied in this order, and untransforms in
the reverse order.
@@ -135,8 +144,8 @@
that arm.
status_quo_features: ObservationFeatures to use as status quo.
Either this or status_quo_name should be specified, not both.
-            optimization_config: Optimization config defining how to optimize
-                the model.
+            optimization_config: An optional ``OptimizationConfig`` defining how to
+                optimize the model. Defaults to `experiment.optimization_config`.
expand_model_space: If True, expand range parameter bounds in model
space to cover given training data. This will make the modeling
space larger than the search space if training data fall outside
@@ -179,6 +188,7 @@ def __init__(
self._model_kwargs: dict[str, Any] | None = None
self._bridge_kwargs: dict[str, Any] | None = None
# The space used for optimization.
+        search_space = search_space or experiment.search_space
self._search_space: SearchSpace = search_space.clone()
# The space used for modeling. Might be larger than the optimization
# space to cover training data.
@@ -194,13 +204,12 @@
experiment is not None and experiment.immutable_search_space_and_opt_config
)
self._experiment_properties: dict[str, Any] = {}
-        self._experiment: Experiment | None = experiment
+        self._experiment: Experiment = experiment

-        if experiment is not None:
-            if self._optimization_config is None:
-                self._optimization_config = experiment.optimization_config
-            self._arms_by_signature = experiment.arms_by_signature
-            self._experiment_properties = experiment._properties
+        if self._optimization_config is None:
+            self._optimization_config = experiment.optimization_config
+        self._arms_by_signature = experiment.arms_by_signature
+        self._experiment_properties = experiment._properties

if self._fit_tracking_metrics is False:
if self._optimization_config is None:
@@ -212,6 +221,7 @@

# Set training data (in the raw / untransformed space). This also omits
# out-of-design and abandoned observations depending on the corresponding flags.
+        data = data if data is not None else experiment.lookup_data()
observations_raw = self._prepare_observations(experiment=experiment, data=data)
if expand_model_space:
self._set_model_space(observations=observations_raw)
@@ -259,11 +269,7 @@ def _fit_if_implemented(
"""
try:
t_fit_start = time.monotonic()
-            self._fit(
-                model=self.model,
-                search_space=search_space,
-                observations=observations,
-            )
+            self._fit(search_space=search_space, observations=observations)
increment = time.monotonic() - t_fit_start + time_so_far
self.fit_time += increment
self.fit_time_since_gen += increment
@@ -477,7 +483,7 @@ def _set_status_quo(

if status_quo_name is not None:
if status_quo_features is not None:
-                raise UserInputError(
+                raise UserInputError(
"Specify either status_quo_name or status_quo_features, not both."
)
sq_obs = [
@@ -596,8 +602,6 @@ def training_in_design(self, training_in_design: list[bool]) -> None:

def _fit(
self,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        model: Any,
search_space: SearchSpace,
observations: list[Observation],
) -> None:
@@ -736,21 +740,6 @@ def _predict(
f"{self.__class__.__name__} does not implement `_predict`."
)

-    def update(self, new_data: Data, experiment: Experiment) -> None:
-        """Update the model bridge and the underlying model with new data. This
-        method should be used instead of `fit`, in cases where the underlying
-        model does not need to be re-fit from scratch, but rather updated.
-
-        Note: `update` expects only new data (obtained since the model initialization
-        or last update) to be passed in, not all data in the experiment.
-
-        Args:
-            new_data: Data from the experiment obtained since the last call to
-                `update`.
-            experiment: Experiment, in which this data was obtained.
-        """
-        raise DeprecationWarning("Adapter.update is deprecated. Use `fit` instead.")
-

def _get_transformed_gen_args(
self,
search_space: SearchSpace,
@@ -1080,13 +1069,13 @@ def _get_serialized_model_state(self) -> dict[str, Any]:
"""Obtains the state of the underlying model (if using a stateful one)
in a readily JSON-serializable form.
"""
-        model = none_throws(self.model)
+        model = self.model
return model.serialize_state(raw_state=model._get_state())

def _deserialize_model_state(
self, serialized_state: dict[str, Any]
) -> dict[str, Any]:
-        model = none_throws(self.model)
+        model = self.model
return model.deserialize_state(serialized_state=serialized_state)

def feature_importances(self, metric_name: str) -> dict[str, float]:
56 changes: 45 additions & 11 deletions ax/modelbridge/discrete.py
@@ -7,6 +7,8 @@
# pyre-strict


+from ax.core.data import Data
+from ax.core.experiment import Experiment
from ax.core.observation import (
Observation,
ObservationData,
@@ -28,6 +30,7 @@
extract_outcome_constraints,
validate_transformed_optimization_config,
)
+from ax.modelbridge.transforms.base import Transform
from ax.models.discrete_base import DiscreteGenerator
from ax.models.types import TConfig

@@ -38,25 +41,56 @@
class DiscreteAdapter(Adapter):
"""A model bridge for using models based on discrete parameters.
-    Requires that all parameters have been transformed to ChoiceParameters.
+    Requires that all parameters to have been transformed to ChoiceParameters.
"""

-    # pyre-fixme[13]: Attribute `model` is never initialized.
-    model: DiscreteGenerator
-    # pyre-fixme[13]: Attribute `outcomes` is never initialized.
-    outcomes: list[str]
-    # pyre-fixme[13]: Attribute `parameters` is never initialized.
-    parameters: list[str]
-    # pyre-fixme[13]: Attribute `search_space` is never initialized.
-    search_space: SearchSpace | None
+    def __init__(
+        self,
+        *,
+        experiment: Experiment,
+        model: DiscreteGenerator,
+        search_space: SearchSpace | None = None,
+        data: Data | None = None,
+        transforms: list[type[Transform]] | None = None,
+        transform_configs: dict[str, TConfig] | None = None,
+        status_quo_name: str | None = None,
+        status_quo_features: ObservationFeatures | None = None,
+        optimization_config: OptimizationConfig | None = None,
+        expand_model_space: bool = True,
+        fit_out_of_design: bool = False,
+        fit_abandoned: bool = False,
+        fit_tracking_metrics: bool = True,
+        fit_on_init: bool = True,
+        fit_only_completed_map_metrics: bool = True,
+    ) -> None:
+        # These are set in _fit.
+        self.parameters: list[str] = []
+        self.outcomes: list[str] = []
+        super().__init__(
+            experiment=experiment,
+            model=model,
+            search_space=search_space,
+            data=data,
+            transforms=transforms,
+            transform_configs=transform_configs,
+            status_quo_name=status_quo_name,
+            status_quo_features=status_quo_features,
+            optimization_config=optimization_config,
+            expand_model_space=expand_model_space,
+            fit_out_of_design=fit_out_of_design,
+            fit_abandoned=fit_abandoned,
+            fit_tracking_metrics=fit_tracking_metrics,
+            fit_on_init=fit_on_init,
+            fit_only_completed_map_metrics=fit_only_completed_map_metrics,
+        )
+        # Re-assign for more precise typing.
+        self.model: DiscreteGenerator = model

def _fit(
self,
-        model: DiscreteGenerator,
search_space: SearchSpace,
observations: list[Observation],
) -> None:
-        self.model = model
# Convert observations to arrays
self.parameters = list(search_space.parameters.keys())
all_metric_names: set[str] = set()
6 changes: 4 additions & 2 deletions ax/modelbridge/factory.py
@@ -73,7 +73,7 @@ def get_sobol(
"""
return assert_is_instance(
Generators.SOBOL(
-            search_space=search_space,
+            experiment=Experiment(search_space=search_space),
seed=seed,
deduplicate=deduplicate,
init_position=init_position,
@@ -98,7 +98,9 @@
"""
return assert_is_instance(
Generators.UNIFORM(
-            search_space=search_space, seed=seed, deduplicate=deduplicate
+            experiment=Experiment(search_space=search_space),
+            seed=seed,
+            deduplicate=deduplicate,
),
RandomAdapter,
)
