Merge pull request #365 from mit-ll-responsible-ai/support-optim-required

Closes #329 -- adds auto-config support for torch.optim.optimizer.required
rsokl authored Dec 30, 2022
2 parents d66887d + c8b3d0c commit 028fd54
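
As a quick illustration of what this change enables (a minimal sketch, not part of the commit; it assumes torch is installed and uses SGD, one optimizer whose `lr` default is the `required` sentinel):

from hydra_zen import builds, to_yaml
from torch.optim import SGD

# Before this change, populate_full_signature=True could not handle SGD's
# `lr=<required parameter>` default; now `lr` becomes a required (missing)
# field that the user must supply.
SGDConf = builds(SGD, populate_full_signature=True, zen_partial=True)
print(to_yaml(SGDConf))  # `lr` should render as the missing value `???`
SGDConf(lr=0.1)  # supplying `lr` yields a valid partial config for SGD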
Showing 4 changed files with 28 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/source/api_reference.rst
@@ -229,7 +229,8 @@ hydra-zen also provides auto-config support for some third-party libraries:

- `pydantic.dataclasses.dataclass`
- `pydantic.Field`
- `torch.optim.optimizer.required` (i.e. the default value of the `lr` parameter in `Optimizer`)


*********************
1 change: 1 addition & 0 deletions docs/source/changes.rst
@@ -119,6 +119,7 @@ New Features
- Adds the :class:`~hydra_zen.ZenStore` class (see :pull:`331`)
- Adds `hydra_zen.store`, which is a pre-initialized instance of :class:`~hydra_zen.ZenStore` (see :pull:`331`)
- The option `hydra_convert='object'` is now supported by all of hydra-zen's config-creation functions, so that an instantiated structured config can be converted to an instance of its backing dataclass. This feature was added in `Hydra 1.3.0 <https://github.com/facebookresearch/hydra/issues/1719>`_.
- Adds auto-config support for `torch.optim.optimizer.required` so that the common pattern `builds(<torch_optimizer_type>, populate_full_signature=True, zen_partial=True)` works and exposes `lr` as a required configurable parameter. Thanks to @addisonklinke for requesting this in :issue:`257`.

Improvements
------------
17 changes: 17 additions & 0 deletions src/hydra_zen/structured_configs/_implementations.py
@@ -538,6 +538,19 @@ def _is_jax_compiled_func(value: Any) -> bool: # pragma: no cover
    return False


def _is_torch_optim_required(value: Any) -> bool: # pragma: no cover
    torch_optim_optimizer = sys.modules.get("torch.optim.optimizer")

    if torch_optim_optimizer is None:
        return False

    try:
        required = getattr(torch_optim_optimizer, "required")
        return value is required
    except AttributeError:
        return False


def _check_for_dynamically_defined_dataclass_type(target_path: str, value: Any) -> None:
    if target_path.startswith("types."):
        raise HydraZenUnsupportedPrimitiveError(
@@ -703,6 +716,10 @@ def sanitized_default_value(
    else: # pragma: no cover
        del _v

    # support for torch objects
    if _is_torch_optim_required(value): # pragma: no cover
        return MISSING

    # `value` could not be converted to a Hydra-compatible representation.
    # Raise error
    if field_name:
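A note on the new `_is_torch_optim_required` helper above: it looks up `torch.optim.optimizer` via `sys.modules.get` instead of importing it, so the `required` sentinel is only recognized when the user has already imported torch, and hydra-zen never pays the import cost itself. A minimal standalone sketch of that pattern (the helper name below is hypothetical, not part of this commit):

import sys
from typing import Any

_NOT_FOUND = object()

def is_sentinel_from(module_name: str, attr: str, value: Any) -> bool:
    # Only consult the module if the user has already imported it;
    # sys.modules.get returns None rather than triggering an import.
    module = sys.modules.get(module_name)
    if module is None:
        return False
    return value is getattr(module, attr, _NOT_FOUND)

# e.g. is_sentinel_from("torch.optim.optimizer", "required", some_default)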
9 changes: 8 additions & 1 deletion tests/test_third_party/test_against_torch.py
@@ -10,7 +10,7 @@
import torch as tr
from hypothesis import assume, given
from omegaconf import OmegaConf
from torch.optim import SGD, Adam, AdamW
from torch.utils.data import DataLoader, Dataset

from hydra_zen import builds, hydrated_dataclass, instantiate, just, to_yaml
@@ -145,3 +145,10 @@ def test_dataloader_sig_doesnt_have_self():
    assert "self" not in {f.name for f in _fields}
    assert len(_fields) > 2
    assert list(instantiate(Conf)) == [tr.tensor(float(i)) for i in range(5)]


def test_auto_config_support_for_optim_required():
    Opt = instantiate(
        builds(SGD, populate_full_signature=True, zen_partial=True)(lr=1.0)
    )
    assert isinstance(Opt(tr.nn.Linear(1, 1).parameters()), SGD)
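
The new test supplies `lr` when the config is created; as a hedged companion (an assumption about usage, not part of this commit's tests), the required field can also be filled in afterwards with a standard OmegaConf merge:

import torch as tr
from omegaconf import OmegaConf
from torch.optim import SGD

from hydra_zen import builds, instantiate

SGDConf = builds(SGD, populate_full_signature=True, zen_partial=True)
cfg = OmegaConf.merge(SGDConf, {"lr": 0.1})  # fill in the required `lr` field
opt = instantiate(cfg)(tr.nn.Linear(1, 1).parameters())
assert isinstance(opt, SGD)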
