Merge pull request #18 from qthequartermasterman/optimizers
feat: ✨ add strategies for torch optimizers
Showing 5 changed files with 223 additions and 43 deletions.
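For orientation, here is a minimal usage sketch (not part of the diff) of how the new strategies can drive a property-based test. It only assumes the package-level exports exercised by the tests in this PR (`hypothesis_torch.optimizer_strategy` and `hypothesis_torch.OptimizerConstructorWithOnlyParameters`); the test name is illustrative.

```python
# Sketch only: draw an optimizer constructor and apply it to a small module's
# parameters, mirroring the usage exercised by the new tests.
import hypothesis
import torch

import hypothesis_torch


@hypothesis.settings(deadline=None)  # drawing and constructing optimizers can be slow
@hypothesis.given(optimizer_constructor=hypothesis_torch.optimizer_strategy())
def test_optimizer_constructs(
    optimizer_constructor: hypothesis_torch.OptimizerConstructorWithOnlyParameters,
) -> None:
    module = torch.nn.Linear(4, 2)
    optimizer = optimizer_constructor(module.parameters())
    assert isinstance(optimizer, torch.optim.Optimizer)
```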
@@ -0,0 +1,121 @@
"""Strategies for generating torch optimizers.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Sequence, Final, Callable, Iterator | ||
from typing_extensions import TypeAlias | ||
|
||
import hypothesis | ||
import torch.optim | ||
import inspect | ||
|
||
from hypothesis import strategies as st | ||
from hypothesis_torch import inspection_util | ||
|
||
OptimizerConstructorWithOnlyParameters: TypeAlias = Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer] | ||
|
||
OPTIMIZERS: Final[tuple[type[torch.optim.Optimizer], ...]] = tuple( | ||
optimizer | ||
for optimizer in inspection_util.get_all_subclasses(torch.optim.Optimizer) | ||
if optimizer is not torch.optim.Optimizer and "NewCls" not in optimizer.__name__ | ||
) | ||
|
||
_ZERO_TO_ONE_FLOATS: Final[st.SearchStrategy[float]] = st.floats( | ||
min_value=0.0, max_value=1.0, exclude_max=True, exclude_min=True | ||
) | ||
|
||
|
||
@st.composite | ||
def betas(draw: st.DrawFn) -> tuple[float, float]: | ||
"""Strategy for generating beta1 and beta2 values for optimizers. | ||
Args: | ||
draw: The draw function provided by `hypothesis`. | ||
Returns: | ||
A tuple of beta1 and beta2 values. | ||
""" | ||
beta1 = draw(st.floats(min_value=0.0, max_value=0.95, exclude_max=True, exclude_min=True)) | ||
beta2 = draw(st.floats(min_value=beta1, max_value=0.999, exclude_max=True, exclude_min=True)) | ||
return beta1, beta2 | ||
|
||
|
||
HYPERPARAM_OVERRIDE_STRATEGIES: Final[dict[str, st.SearchStrategy]] = { | ||
"lr": _ZERO_TO_ONE_FLOATS, | ||
"weight_decay": _ZERO_TO_ONE_FLOATS, | ||
"momentum": _ZERO_TO_ONE_FLOATS, | ||
"betas": betas(), | ||
"lr_decay": _ZERO_TO_ONE_FLOATS, | ||
"eps": _ZERO_TO_ONE_FLOATS, | ||
"centered": st.booleans(), | ||
"rho": _ZERO_TO_ONE_FLOATS, | ||
"momentum_decay": _ZERO_TO_ONE_FLOATS, | ||
"etas": st.tuples( | ||
st.floats(min_value=0.0, max_value=1.0, exclude_min=True, exclude_max=True), | ||
st.floats(min_value=1.0, max_value=2.0, exclude_min=True, exclude_max=True), | ||
), | ||
"dampening": _ZERO_TO_ONE_FLOATS, | ||
"nesterov": st.booleans(), | ||
"initial_accumulator_value": _ZERO_TO_ONE_FLOATS, | ||
} | ||
|
||
|
||
def optimizer_type_strategy( | ||
allowed_optimizer_types: Sequence[type[torch.optim.Optimizer]] | None = None, | ||
) -> st.SearchStrategy[type[torch.optim.Optimizer]]: | ||
"""Strategy for generating torch optimizers. | ||
Args: | ||
allowed_optimizer_types: A sequence of optimizers to sample from. If None, all available optimizers are sampled. | ||
Returns: | ||
A strategy for generating torch optimizers. | ||
""" | ||
if allowed_optimizer_types is None: | ||
allowed_optimizer_types = OPTIMIZERS | ||
return st.sampled_from(allowed_optimizer_types) | ||
|
||
|
||
@st.composite | ||
def optimizer_strategy( | ||
draw: st.DrawFn, | ||
optimizer_type: type[torch.optim.Optimizer] | st.SearchStrategy[type[torch.optim.Optimizer]] = None, | ||
**kwargs, | ||
) -> st.SearchStrategy[OptimizerConstructorWithOnlyParameters]: | ||
"""Strategy for generating torch optimizers. | ||
Args: | ||
draw: The draw function provided by `hypothesis`. | ||
optimizer_type: The optimizer type or a strategy for generating optimizer types. | ||
kwargs: Keyword arguments to pass to the optimizer constructor. If a keyword argument is a strategy, it will be | ||
drawn from. | ||
""" | ||
if optimizer_type is None: | ||
optimizer_type = optimizer_type_strategy() | ||
if isinstance(optimizer_type, st.SearchStrategy): | ||
optimizer_type = draw(optimizer_type) | ||
|
||
sig = inspection_util.infer_signature_annotations(optimizer_type.__init__) | ||
for param in sig.parameters.values(): | ||
if param.name in kwargs and isinstance(kwargs[param.name], st.SearchStrategy): | ||
kwargs[param.name] = draw(kwargs[param.name]) | ||
elif param.annotation is inspect.Parameter.empty: | ||
continue | ||
elif param.name in HYPERPARAM_OVERRIDE_STRATEGIES: | ||
kwargs[param.name] = draw(HYPERPARAM_OVERRIDE_STRATEGIES[param.name]) | ||
elif param.annotation in (float, int): | ||
kwargs[param.name] = draw(_ZERO_TO_ONE_FLOATS) | ||
else: | ||
kwargs[param.name] = draw(st.from_type(param.annotation)) | ||
if "nesterov" in kwargs and kwargs["nesterov"] and "momentum" in kwargs: | ||
kwargs["dampening"] = 0 | ||
kwargs.pop("self", None) # Remove self if a type was inferred | ||
kwargs.pop("params", None) # Remove params if a type was inferred | ||
|
||
hypothesis.note(f"Chosen optimizer type: {optimizer_type}") | ||
hypothesis.note(f"Chosen optimizer hyperparameters: {kwargs}") | ||
|
||
def optimizer_factory(params: Sequence[torch.nn.Parameter]) -> torch.optim.Optimizer: | ||
return optimizer_type(params, **kwargs) | ||
|
||
return optimizer_factory |
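Because keyword arguments passed as strategies are drawn inside `optimizer_strategy`, individual hyperparameters can be constrained while the remaining ones are still drawn from `HYPERPARAM_OVERRIDE_STRATEGIES`. A brief sketch (not part of the diff; the test name and bounds are illustrative):

```python
# Sketch only: pin the optimizer type to Adam and bound the learning rate,
# leaving the other hyperparameters to the default override strategies.
import hypothesis
import torch
from hypothesis import strategies as st

import hypothesis_torch


@hypothesis.settings(deadline=None)
@hypothesis.given(
    optimizer_constructor=hypothesis_torch.optimizer_strategy(
        torch.optim.Adam,
        lr=st.floats(min_value=1e-4, max_value=1e-2),
    )
)
def test_adam_with_bounded_lr(
    optimizer_constructor: hypothesis_torch.OptimizerConstructorWithOnlyParameters,
) -> None:
    optimizer = optimizer_constructor(torch.nn.Linear(4, 2).parameters())
    assert isinstance(optimizer, torch.optim.Adam)
    assert 1e-4 <= optimizer.param_groups[0]["lr"] <= 1e-2
```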
@@ -0,0 +1,50 @@
"""Tests for the `hypothesis_torch.optim` module.""" | ||
|
||
import unittest | ||
|
||
import hypothesis_torch | ||
import hypothesis | ||
from hypothesis import strategies as st | ||
|
||
import torch | ||
|
||
|
||
class TestOptimizerTypeStrategy(unittest.TestCase): | ||
"""Tests for the `optimizer_type_strategy` function.""" | ||
|
||
@hypothesis.given(optimizer_type=hypothesis_torch.optimizer_type_strategy()) | ||
def test_optimizer_type_strategy(self, optimizer_type: type[torch.optim.Optimizer]) -> None: | ||
"""Test that `optimizer_type_strategy` generates valid optimizer types.""" | ||
self.assertTrue(issubclass(optimizer_type, torch.optim.Optimizer)) | ||
self.assertNotEqual(optimizer_type, torch.optim.Optimizer) | ||
self.assertNotIn("NewCls", optimizer_type.__name__) | ||
|
||
@hypothesis.given( | ||
optimizer_type=hypothesis_torch.optimizer_type_strategy(allowed_optimizer_types=[torch.optim.Adam]) | ||
) | ||
def test_optimizer_type_strategy_allowed_optimizer_types(self, optimizer_type: type[torch.optim.Optimizer]) -> None: | ||
"""Test that `optimizer_type_strategy` generates optimizer types when specifying `allowed_optimizer_types`.""" | ||
self.assertEqual(optimizer_type, torch.optim.Adam) | ||
|
||
|
||
class TestOptimizerStrategy(unittest.TestCase): | ||
"""Tests for the `optimizer_strategy` function.""" | ||
|
||
@hypothesis.settings(deadline=None) | ||
@hypothesis.given( | ||
optimizer_constructor=hypothesis_torch.optimizer_strategy(), | ||
module=hypothesis_torch.linear_network_strategy( | ||
input_shape=(1, 1), | ||
output_shape=(1, 1), | ||
activation_layer=torch.nn.ReLU(), | ||
hidden_layer_size=st.integers(min_value=1, max_value=10), | ||
num_hidden_layers=st.integers(min_value=1, max_value=10), | ||
device=hypothesis_torch.device_strategy(), | ||
), | ||
) | ||
def test_optimizer_strategy( | ||
self, optimizer_constructor: hypothesis_torch.OptimizerConstructorWithOnlyParameters, module: torch.nn.Module | ||
) -> None: | ||
"""Test that `optimizer_strategy` generates valid optimizers.""" | ||
optimizer = optimizer_constructor(module.parameters()) | ||
self.assertIsInstance(optimizer, torch.optim.Optimizer) |
Coverage Report (commit fbefa06)