Changes to Adapter constructors, require experiment (#3415)
Summary:
Pull Request resolved: #3415

The motivation for this change is to bring us closer to relying on the `experiment` as the source of truth about the experiment state & attributes within the modeling layer. We currently support many inputs that are simply extracted from the `experiment`, only to be passed in alongside it. By making `experiment` a required input, we open the possibility of removing these extra inputs and extracting them directly from `experiment` where they're needed.

Makes the following changes to Adapter constructors (a sketch of the new calling convention follows the lists below):
- Requires keyword-only arguments. Positional inputs are no longer supported.
- Makes `experiment` required and `search_space` optional.
- Re-orders inputs for consistency across sub-classes.

In addition:
- Removes `model` input to `Adapter._fit`. This is a private method that is only called through `fit_if_implemented` (with `self.model`). Accepting multiple inputs for the same argument only makes the code harder to reason about.
- Removes class level attributes, some of which weren't initialized in `__init__`, leading to pyre complaints. All attributes are now initialized in `__init__`. This also eliminates misleading "optional" type hints with `None` default for `model`, which is never `None` in practice.
- Removes `Adapter.update`, which has been deprecated for quite some time.
- Initializing a Generator from the registry with only `search_space` is deprecated. It is temporarily supported using a dummy experiment for random & discrete adapters, which previously did not require an `experiment`.
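
As a rough sketch of the new calling convention (the one-parameter search space is illustrative; the registry call mirrors the `get_sobol` change in `ax/modelbridge/factory.py` below):

```python
from ax.core.experiment import Experiment
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.registry import Generators

search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
)
# Before: Generators.SOBOL(search_space=search_space)  # now deprecated
# After: `experiment` is required and keyword-only; a bare Experiment
# wrapping the search space suffices for random generators.
sobol = Generators.SOBOL(experiment=Experiment(search_space=search_space))
```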

Differential Revision: D70103442
saitcakmak authored and facebook-github-bot committed Feb 25, 2025
1 parent c64d122 commit adfa3d8
Showing 21 changed files with 492 additions and 651 deletions.
81 changes: 35 additions & 46 deletions ax/modelbridge/base.py
@@ -45,6 +45,7 @@
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cast import Cast
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
from ax.models.base import Generator
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from botorch.settings import validate_input_scaling
@@ -90,18 +91,18 @@ class Adapter(ABC): # noqa: B024 -- Adapter doesn't have any abstract methods.
receives appropriate inputs.
Subclasses will implement what is here referred to as the "terminal
transform," which is a transform that changes types of the data and problem
transform", which is a transform that changes types of the data and problem
specification.
"""

def __init__(
self,
search_space: SearchSpace,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
model: Any,
transforms: list[type[Transform]] | None = None,
experiment: Experiment | None = None,
*,
experiment: Experiment,
model: Generator,
search_space: SearchSpace | None = None,
data: Data | None = None,
transforms: list[type[Transform]] | None = None,
transform_configs: dict[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
@@ -117,14 +118,22 @@ def __init__(
Applies transforms and fits model.
Args:
experiment: Is used to get arm parameters. Is not mutated.
search_space: Search space for fitting the model. Constraints need
not be the same ones used in gen. RangeParameter bounds are
considered soft and will be expanded to match the range of the
data sent in for fitting, if expand_model_space is True.
data: Ax Data.
model: Interface will be specified in subclass. If model requires
experiment: An ``Experiment`` object representing the setup and the
current state of the experiment, including the search space,
trials and observation data. It is used to extract various
attributes, and is not mutated.
model: A ``Generator`` that is used for generating candidates.
Its interface will be specified in subclasses. If model requires
initialization, that should be done prior to its use here.
search_space: An optional ``SearchSpace`` for fitting the model.
If not provided, `experiment.search_space` is used.
The search space may be modified during ``Adapter.gen``, e.g.,
to try out a different set of parameter bounds or constraints.
The bounds of the ``RangeParameter``s are considered soft and
will be expanded to match the range of the data sent in for fitting,
if `expand_model_space` is True.
data: An optional ``Data`` object, containing mean and SEM observations.
If `None`, extracted using `experiment.lookup_data()`.
transforms: List of uninitialized transform classes. Forward
transforms will be applied in this order, and untransforms in
the reverse order.
@@ -135,8 +144,8 @@
that arm.
status_quo_features: ObservationFeatures to use as status quo.
Either this or status_quo_name should be specified, not both.
optimization_config: Optimization config defining how to optimize
the model.
optimization_config: An optional ``OptimizationConfig`` defining how to
optimize the model. Defaults to `experiment.optimization_config`.
expand_model_space: If True, expand range parameter bounds in model
space to cover given training data. This will make the modeling
space larger than the search space if training data fall outside
@@ -179,6 +188,7 @@ def __init__(
self._model_kwargs: dict[str, Any] | None = None
self._bridge_kwargs: dict[str, Any] | None = None
# The space used for optimization.
search_space = search_space or experiment.search_space
self._search_space: SearchSpace = search_space.clone()
# The space used for modeling. Might be larger than the optimization
# space to cover training data.
@@ -194,13 +204,12 @@
experiment is not None and experiment.immutable_search_space_and_opt_config
)
self._experiment_properties: dict[str, Any] = {}
self._experiment: Experiment | None = experiment
self._experiment: Experiment = experiment

if experiment is not None:
if self._optimization_config is None:
self._optimization_config = experiment.optimization_config
self._arms_by_signature = experiment.arms_by_signature
self._experiment_properties = experiment._properties
if self._optimization_config is None:
self._optimization_config = experiment.optimization_config
self._arms_by_signature = experiment.arms_by_signature
self._experiment_properties = experiment._properties

if self._fit_tracking_metrics is False:
if self._optimization_config is None:
@@ -212,6 +221,7 @@

# Set training data (in the raw / untransformed space). This also omits
# out-of-design and abandoned observations depending on the corresponding flags.
data = data if data is not None else experiment.lookup_data()
observations_raw = self._prepare_observations(experiment=experiment, data=data)
if expand_model_space:
self._set_model_space(observations=observations_raw)
@@ -259,11 +269,7 @@ def _fit_if_implemented(
"""
try:
t_fit_start = time.monotonic()
self._fit(
model=self.model,
search_space=search_space,
observations=observations,
)
self._fit(search_space=search_space, observations=observations)
increment = time.monotonic() - t_fit_start + time_so_far
self.fit_time += increment
self.fit_time_since_gen += increment
@@ -477,7 +483,7 @@ def _set_status_quo(

if status_quo_name is not None:
if status_quo_features is not None:
raise ValueError(
raise UserInputError(
"Specify either status_quo_name or status_quo_features, not both."
)
sq_obs = [
@@ -596,8 +602,6 @@ def training_in_design(self, training_in_design: list[bool]) -> None:

def _fit(
self,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
model: Any,
search_space: SearchSpace,
observations: list[Observation],
) -> None:
@@ -736,21 +740,6 @@ def _predict(
f"{self.__class__.__name__} does not implement `_predict`."
)

def update(self, new_data: Data, experiment: Experiment) -> None:
"""Update the model bridge and the underlying model with new data. This
method should be used instead of `fit`, in cases where the underlying
model does not need to be re-fit from scratch, but rather updated.
Note: `update` expects only new data (obtained since the model initialization
or last update) to be passed in, not all data in the experiment.
Args:
new_data: Data from the experiment obtained since the last call to
`update`.
experiment: Experiment, in which this data was obtained.
"""
raise DeprecationWarning("Adapter.update is deprecated. Use `fit` instead.")

def _get_transformed_gen_args(
self,
search_space: SearchSpace,
@@ -1080,13 +1069,13 @@ def _get_serialized_model_state(self) -> dict[str, Any]:
"""Obtains the state of the underlying model (if using a stateful one)
in a readily JSON-serializable form.
"""
model = none_throws(self.model)
model = self.model
return model.serialize_state(raw_state=model._get_state())

def _deserialize_model_state(
self, serialized_state: dict[str, Any]
) -> dict[str, Any]:
model = none_throws(self.model)
model = self.model
return model.deserialize_state(serialized_state=serialized_state)

def feature_importances(self, metric_name: str) -> dict[str, float]:
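
To illustrate the `_fit` signature change for subclass authors, a minimal hypothetical sketch (`MyAdapter` is not part of this diff):

```python
from ax.core.observation import Observation
from ax.core.search_space import SearchSpace
from ax.modelbridge.base import Adapter


class MyAdapter(Adapter):
    def _fit(
        self,
        search_space: SearchSpace,
        observations: list[Observation],
    ) -> None:
        # `model` is no longer a parameter: `_fit` is only invoked through
        # `_fit_if_implemented`, which always operates on `self.model`,
        # set in `__init__` and never `None` after this change.
        pass  # a real subclass would convert observations and fit self.model
```
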
56 changes: 45 additions & 11 deletions ax/modelbridge/discrete.py
@@ -7,6 +7,8 @@
# pyre-strict


from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.observation import (
Observation,
ObservationData,
@@ -28,6 +30,7 @@
extract_outcome_constraints,
validate_transformed_optimization_config,
)
from ax.modelbridge.transforms.base import Transform
from ax.models.discrete_base import DiscreteGenerator
from ax.models.types import TConfig

@@ -38,25 +41,56 @@
class DiscreteAdapter(Adapter):
"""A model bridge for using models based on discrete parameters.
Requires that all parameters have been transformed to ChoiceParameters.
"""

# pyre-fixme[13]: Attribute `model` is never initialized.
model: DiscreteGenerator
# pyre-fixme[13]: Attribute `outcomes` is never initialized.
outcomes: list[str]
# pyre-fixme[13]: Attribute `parameters` is never initialized.
parameters: list[str]
# pyre-fixme[13]: Attribute `search_space` is never initialized.
search_space: SearchSpace | None
def __init__(
self,
*,
experiment: Experiment,
model: DiscreteGenerator,
search_space: SearchSpace | None = None,
data: Data | None = None,
transforms: list[type[Transform]] | None = None,
transform_configs: dict[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
expand_model_space: bool = True,
fit_out_of_design: bool = False,
fit_abandoned: bool = False,
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
) -> None:
# These are set in _fit.
self.parameters: list[str] = []
self.outcomes: list[str] = []
super().__init__(
experiment=experiment,
model=model,
search_space=search_space,
data=data,
transforms=transforms,
transform_configs=transform_configs,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
optimization_config=optimization_config,
expand_model_space=expand_model_space,
fit_out_of_design=fit_out_of_design,
fit_abandoned=fit_abandoned,
fit_tracking_metrics=fit_tracking_metrics,
fit_on_init=fit_on_init,
fit_only_completed_map_metrics=fit_only_completed_map_metrics,
)
# Re-assign for more precise typing.
self.model: DiscreteGenerator = model

def _fit(
self,
model: DiscreteGenerator,
search_space: SearchSpace,
observations: list[Observation],
) -> None:
self.model = model
# Convert observations to arrays
self.parameters = list(search_space.parameters.keys())
all_metric_names: set[str] = set()
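
The trailing `self.model: DiscreteGenerator = model` assignment exists only to narrow the attribute's type for the type checker; a standalone sketch of the pattern, with hypothetical stand-in classes:

```python
class Generator:
    """Stand-in for the generic base model type."""


class DiscreteGenerator(Generator):
    """Stand-in for the narrower model type."""


class Base:
    def __init__(self, *, model: Generator) -> None:
        self.model = model


class Derived(Base):
    def __init__(self, *, model: DiscreteGenerator) -> None:
        super().__init__(model=model)
        # Re-declare with the narrower type so `Derived(...).model`
        # type-checks as `DiscreteGenerator` without a cast.
        self.model: DiscreteGenerator = model
```
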
6 changes: 4 additions & 2 deletions ax/modelbridge/factory.py
@@ -73,7 +73,7 @@ def get_sobol(
"""
return assert_is_instance(
Generators.SOBOL(
search_space=search_space,
experiment=Experiment(search_space=search_space),
seed=seed,
deduplicate=deduplicate,
init_position=init_position,
@@ -98,7 +98,9 @@
"""
return assert_is_instance(
Generators.UNIFORM(
search_space=search_space, seed=seed, deduplicate=deduplicate
experiment=Experiment(search_space=search_space),
seed=seed,
deduplicate=deduplicate,
),
RandomAdapter,
)
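
Call sites of the factory helpers are unchanged, since the wrapping `Experiment` is an internal detail; a hedged usage sketch (the search space is illustrative):

```python
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.factory import get_sobol

search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
)
sobol = get_sobol(search_space=search_space, seed=0)
generator_run = sobol.gen(n=3)  # three quasi-random candidate arms
```
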
53 changes: 10 additions & 43 deletions ax/modelbridge/map_torch.py
@@ -10,7 +10,6 @@

import numpy as np
import numpy.typing as npt

import torch
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
@@ -47,7 +46,7 @@
DEFAULT_TARGET_MAP_VALUES = {"steps": 1.0}


class MapTorchAdapter(TorchAdapter):
"""A model bridge for using torch-based models that fit on MapData. Most
of the `TorchAdapter` functionality is retained, except that this
class should be used in the case where `model` makes use of map_key values.
@@ -57,71 +56,41 @@ class should be used in the case where `model` makes use of map_key values.

def __init__(
self,
*,
experiment: Experiment,
search_space: SearchSpace,
data: Data,
model: TorchGenerator,
transforms: list[type[Transform]],
search_space: SearchSpace | None = None,
data: Data | None = None,
transforms: list[type[Transform]] | None = None,
transform_configs: dict[str, TConfig] | None = None,
torch_device: torch.device | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
fit_out_of_design: bool = False,
fit_on_init: bool = True,
fit_abandoned: bool = False,
default_model_gen_options: TConfig | None = None,
torch_device: torch.device | None = None,
map_data_limit_rows_per_metric: int | None = None,
map_data_limit_rows_per_group: int | None = None,
) -> None:
"""
Applies transforms and fits model.
"""In addition to common arguments documented in the ``Adapter`` and
``TorchAdapter`` classes, ``MapTorchAdapter`` accepts the following arguments.
Args:
experiment: Is used to get arm parameters. Is not mutated.
search_space: Search space for fitting the model. Constraints need
not be the same ones used in gen.
data: Ax Data.
model: Interface will be specified in subclass. If model requires
initialization, that should be done prior to its use here.
transforms: List of uninitialized transform classes. Forward
transforms will be applied in this order, and untransforms in
the reverse order.
transform_configs: A dictionary from transform name to the
transform config dictionary.
torch_device: Torch device.
status_quo_name: Name of the status quo arm. Can only be used if
Data has a single set of ObservationFeatures corresponding to
that arm.
status_quo_features: ObservationFeatures to use as status quo.
Either this or status_quo_name should be specified, not both.
optimization_config: Optimization config defining how to optimize
the model.
fit_out_of_design: If specified, all training data is returned.
Otherwise, only in design points are returned.
fit_on_init: Whether to fit the model on initialization. This can
be used to skip model fitting when a fitted model is not needed.
To fit the model afterwards, use `_process_and_transform_data`
to get the transformed inputs and call `_fit_if_implemented` with
the transformed inputs.
fit_abandoned: Whether data for abandoned arms or trials should be
included in model training data. If ``False``, only
non-abandoned points are returned.
default_model_gen_options: Options passed down to `model.gen(...)`.
map_data_limit_rows_per_metric: Subsample the map data so that the
total number of rows per metric is limited by this value.
map_data_limit_rows_per_group: Subsample the map data so that the
number of rows in the `map_key` column for each (arm, metric)
is limited by this value.
"""

data = data or experiment.lookup_data()
if not isinstance(data, MapData):
raise ValueError("`MapTorchAdapter expects `MapData` instead of `Data`.")

if any(isinstance(t, BatchTrial) for t in experiment.trials.values()):
raise ValueError("MapTorchAdapter does not support batch trials.")
# pyre-fixme[4]: Attribute must be annotated.
self._map_key_features = data.map_keys
self._map_key_features: list[str] = data.map_keys
self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric
self._map_data_limit_rows_per_group = map_data_limit_rows_per_group

@@ -188,7 +157,6 @@ def _predict(

def _fit(
self,
model: TorchGenerator,
search_space: SearchSpace,
observations: list[Observation],
parameters: list[str] | None = None,
@@ -201,7 +169,6 @@
if parameters is None:
parameters = self.parameters_with_map_keys
super()._fit(
model=model,
search_space=search_space,
observations=observations,
parameters=parameters,
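
Since `data` is optional here as well, the constructor resolves it from the experiment before validating its type; a minimal sketch of that pattern (the helper name is hypothetical):

```python
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData


def resolve_map_data(experiment: Experiment, data: Data | None) -> MapData:
    # Mirrors the constructor logic above: fall back to the experiment's
    # attached data, then require `MapData` specifically.
    data = data or experiment.lookup_data()
    if not isinstance(data, MapData):
        raise ValueError("MapTorchAdapter expects `MapData` instead of `Data`.")
    return data
```
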
(Diffs for the remaining 17 changed files are not shown.)