From 5cf6181ef6617eda19eb0b51658f835280421c63 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Wed, 24 Apr 2024 23:10:23 -0500
Subject: [PATCH] feat: :sparkles: add strategies for torch optimizer types
 and torch optimizer instances

NOTE: The strategy for torch optimizers actually generates an "alternate
constructor" for a torch optimizer that takes in only a torch module's
parameters. The strategy will "pre-fill" all of the hyperparameters. If these
hyperparameters should be overridden, they can be specified as kwargs to the
strategy.
---
 hypothesis_torch/__init__.py        |   1 +
 hypothesis_torch/inspection_util.py |  23 ++++++
 hypothesis_torch/module.py          |  71 +++++++---------
 hypothesis_torch/optim.py           | 121 ++++++++++++++++++++++++++++
 tests/unit/test_optim.py            |  50 ++++++++++++
 5 files changed, 223 insertions(+), 43 deletions(-)
 create mode 100644 hypothesis_torch/optim.py
 create mode 100644 tests/unit/test_optim.py

diff --git a/hypothesis_torch/__init__.py b/hypothesis_torch/__init__.py
index 5e5cc40..bc1e65d 100644
--- a/hypothesis_torch/__init__.py
+++ b/hypothesis_torch/__init__.py
@@ -28,6 +28,7 @@ from hypothesis_torch.layout import layout_strategy
 from hypothesis_torch.memory_format import memory_format_strategy
 from hypothesis_torch.module import linear_network_strategy, same_shape_activation_strategy
+from hypothesis_torch.optim import optimizer_strategy, optimizer_type_strategy, OptimizerConstructorWithOnlyParameters
 from hypothesis_torch.register_random_torch_state import TORCH_RANDOM_WRAPPER
 from hypothesis_torch.tensor import tensor_strategy
diff --git a/hypothesis_torch/inspection_util.py b/hypothesis_torch/inspection_util.py
index 7a711d7..b2419aa 100644
--- a/hypothesis_torch/inspection_util.py
+++ b/hypothesis_torch/inspection_util.py
@@ -4,6 +4,9 @@
 import inspect
 from typing import Callable, TypeVar

+from hypothesis import strategies as st
+
+import torch

 T = TypeVar("T")

@@ -59,3 +62,23 @@ def get_all_subclasses(cls: type[T]) -> list[type[T]]:
         all_subclasses.append(subclass)
         all_subclasses.extend(get_all_subclasses(subclass))
     return all_subclasses
+
+
+@st.composite
+def signature_to_strategy(draw: st.DrawFn, constructor: type[T], *args, **kwargs) -> T:
+    """Strategy for generating instances of a class by drawing values for its constructor.
+
+    Args:
+        draw: The draw function provided by `hypothesis`.
+        constructor: The class to generate an instance of.
+        args: Positional arguments to pass to the constructor. If an argument is a strategy, it will be drawn from.
+        kwargs: Keyword arguments to pass to the constructor. If a keyword argument is a strategy, it will be drawn
+            from.
+
+    Returns:
+        An instance of the class.
+
+    """
+    args_drawn = [draw(arg) if isinstance(arg, st.SearchStrategy) else arg for arg in args]
+    kwargs_drawn = {k: draw(v) if isinstance(v, st.SearchStrategy) else v for k, v in kwargs.items()}
+    return constructor(*args_drawn, **kwargs_drawn)
diff --git a/hypothesis_torch/module.py b/hypothesis_torch/module.py
index aa5d206..b6da174 100644
--- a/hypothesis_torch/module.py
+++ b/hypothesis_torch/module.py
@@ -4,6 +4,7 @@

 from typing import TypeVar, Sequence, Mapping

+from hypothesis_torch import inspection_util
 import torch
 from hypothesis import strategies as st
 from torch import nn
@@ -60,26 +61,6 @@ def rrelu_strategy(draw: st.DrawFn) -> nn.RReLU:
     return nn.RReLU(lower, upper, inplace)


-@st.composite
-def signature_to_strategy(draw: st.DrawFn, constructor: type[T], *args, **kwargs) -> T:
-    """Strategy for generating instances of a class by drawing values for its constructor.
-
-    Args:
-        draw: The draw function provided by `hypothesis`.
-        constructor: The class to generate an instance of.
-        args: Positional arguments to pass to the constructor. If an argument is a strategy, it will be drawn from.
-        kwargs: Keyword arguments to pass to the constructor. If a keyword argument is a strategy, it will be drawn
-            from.
-
-    Returns:
-        An instance of the class.
-
-    """
-    args_drawn = [draw(strategy) for strategy in args]
-    kwargs_drawn = {k: draw(strategy) for k, strategy in kwargs.items()}
-    return constructor(*args_drawn, **kwargs_drawn)
-
-
 @st.composite
 def hard_tanh_strategy(draw: st.DrawFn) -> nn.Hardtanh:
     """Strategy for generating instances of `nn.Hardtanh` by drawing values for its constructor.
@@ -98,38 +79,42 @@ def hard_tanh_strategy(draw: st.DrawFn) -> nn.Hardtanh:


 activation_strategies: dict[type[nn.Module], st.SearchStrategy[nn.Module]] = {
-    nn.Identity: signature_to_strategy(nn.Identity),
-    nn.ELU: signature_to_strategy(nn.ELU, alpha=SENSIBLE_FLOATS, inplace=st.booleans()),
-    nn.Hardshrink: signature_to_strategy(nn.Hardshrink, lambd=SENSIBLE_FLOATS),
-    nn.Hardsigmoid: signature_to_strategy(nn.Hardsigmoid, inplace=st.booleans()),
+    nn.Identity: inspection_util.signature_to_strategy(nn.Identity),
+    nn.ELU: inspection_util.signature_to_strategy(nn.ELU, alpha=SENSIBLE_FLOATS, inplace=st.booleans()),
+    nn.Hardshrink: inspection_util.signature_to_strategy(nn.Hardshrink, lambd=SENSIBLE_FLOATS),
+    nn.Hardsigmoid: inspection_util.signature_to_strategy(nn.Hardsigmoid, inplace=st.booleans()),
     nn.Hardtanh: hard_tanh_strategy(),
-    nn.Hardswish: signature_to_strategy(nn.Hardswish, inplace=st.booleans()),
-    nn.LeakyReLU: signature_to_strategy(nn.LeakyReLU, negative_slope=SENSIBLE_FLOATS, inplace=st.booleans()),
-    nn.LogSigmoid: signature_to_strategy(nn.LogSigmoid),
+    nn.Hardswish: inspection_util.signature_to_strategy(nn.Hardswish, inplace=st.booleans()),
+    nn.LeakyReLU: inspection_util.signature_to_strategy(
+        nn.LeakyReLU, negative_slope=SENSIBLE_FLOATS, inplace=st.booleans()
+    ),
+    nn.LogSigmoid: inspection_util.signature_to_strategy(nn.LogSigmoid),
     # TODO: nn.MultiheadAttention, although in the `Non-linear activations` section, does not have the same shape
     # inside and outside
     # TODO: nn.PReLU(num_parameters=1, init=0.25, device=None, dtype=None)
     # TODO: PReLU might depend on the input shape
     # TODO: num_parameters (int) – the number of `a` values to learn. Although it takes an int as input, only two
     # values are legitimate: 1, or the number of channels at input. Default: 1
-    nn.PReLU: signature_to_strategy(nn.PReLU, num_parameters=st.just(1), init=SENSIBLE_FLOATS),
-    nn.ReLU: signature_to_strategy(nn.ReLU, inplace=st.booleans()),
-    nn.ReLU6: signature_to_strategy(nn.ReLU6, inplace=st.booleans()),
+    nn.PReLU: inspection_util.signature_to_strategy(nn.PReLU, num_parameters=st.just(1), init=SENSIBLE_FLOATS),
+    nn.ReLU: inspection_util.signature_to_strategy(nn.ReLU, inplace=st.booleans()),
+    nn.ReLU6: inspection_util.signature_to_strategy(nn.ReLU6, inplace=st.booleans()),
     nn.RReLU: rrelu_strategy(),
-    nn.SELU: signature_to_strategy(nn.SELU, inplace=st.booleans()),
-    nn.CELU: signature_to_strategy(
+    nn.SELU: inspection_util.signature_to_strategy(nn.SELU, inplace=st.booleans()),
+    nn.CELU: inspection_util.signature_to_strategy(
         nn.CELU, alpha=SENSIBLE_FLOATS.filter(lambda x: abs(x) > 1e-5), inplace=st.booleans()
     ),
-    nn.GELU: signature_to_strategy(nn.GELU, approximate=st.sampled_from(["none", "tanh"])),
-    nn.Sigmoid: signature_to_strategy(nn.Sigmoid),
-    nn.SiLU: signature_to_strategy(nn.SiLU, inplace=st.booleans()),
-    nn.Mish: signature_to_strategy(nn.Mish, inplace=st.booleans()),
-    nn.Softplus: signature_to_strategy(nn.Softplus, beta=SENSIBLE_FLOATS, threshold=SENSIBLE_POSITIVE_FLOATS),
-    nn.Softshrink: signature_to_strategy(nn.Softshrink, lambd=SENSIBLE_POSITIVE_FLOATS),
-    nn.Softsign: signature_to_strategy(nn.Softsign),
-    nn.Tanh: signature_to_strategy(nn.Tanh),
-    nn.Tanhshrink: signature_to_strategy(nn.Tanhshrink),
-    nn.Threshold: signature_to_strategy(
+    nn.GELU: inspection_util.signature_to_strategy(nn.GELU, approximate=st.sampled_from(["none", "tanh"])),
+    nn.Sigmoid: inspection_util.signature_to_strategy(nn.Sigmoid),
+    nn.SiLU: inspection_util.signature_to_strategy(nn.SiLU, inplace=st.booleans()),
+    nn.Mish: inspection_util.signature_to_strategy(nn.Mish, inplace=st.booleans()),
+    nn.Softplus: inspection_util.signature_to_strategy(
+        nn.Softplus, beta=SENSIBLE_FLOATS, threshold=SENSIBLE_POSITIVE_FLOATS
+    ),
+    nn.Softshrink: inspection_util.signature_to_strategy(nn.Softshrink, lambd=SENSIBLE_POSITIVE_FLOATS),
+    nn.Softsign: inspection_util.signature_to_strategy(nn.Softsign),
+    nn.Tanh: inspection_util.signature_to_strategy(nn.Tanh),
+    nn.Tanhshrink: inspection_util.signature_to_strategy(nn.Tanhshrink),
+    nn.Threshold: inspection_util.signature_to_strategy(
         nn.Threshold, threshold=SENSIBLE_FLOATS, value=SENSIBLE_FLOATS, inplace=st.booleans()
     ),
     # TODO: nn.GLU depends on the input shape
@@ -206,7 +191,7 @@ def linear_network_strategy(
     if isinstance(device, st.SearchStrategy):
         device = draw(device)

-    if isinstance(activation_layer, nn.Module):
+    if not isinstance(activation_layer, st.SearchStrategy):
         activation_layer_strategy = st.just(activation_layer)
     else:
         activation_layer_strategy = activation_layer
diff --git a/hypothesis_torch/optim.py b/hypothesis_torch/optim.py
new file mode 100644
index 0000000..6b83cd1
--- /dev/null
+++ b/hypothesis_torch/optim.py
@@ -0,0 +1,121 @@
+"""Strategies for generating torch optimizers."""
+
+from __future__ import annotations
+
+import inspect
+from typing import Callable, Final, Iterator, Sequence
+
+import hypothesis
+import torch.optim
+from hypothesis import strategies as st
+from typing_extensions import TypeAlias
+
+from hypothesis_torch import inspection_util
+
+OptimizerConstructorWithOnlyParameters: TypeAlias = Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]
+
+OPTIMIZERS: Final[tuple[type[torch.optim.Optimizer], ...]] = tuple(
+    optimizer
+    for optimizer in inspection_util.get_all_subclasses(torch.optim.Optimizer)
+    if optimizer is not torch.optim.Optimizer and "NewCls" not in optimizer.__name__
+)
+
+_ZERO_TO_ONE_FLOATS: Final[st.SearchStrategy[float]] = st.floats(
+    min_value=0.0, max_value=1.0, exclude_max=True, exclude_min=True
+)
+
+
+@st.composite
+def betas(draw: st.DrawFn) -> tuple[float, float]:
+    """Strategy for generating beta1 and beta2 values for optimizers.
+
+    Args:
+        draw: The draw function provided by `hypothesis`.
+
+    Returns:
+        A tuple of beta1 and beta2 values.
+    """
+    beta1 = draw(st.floats(min_value=0.0, max_value=0.95, exclude_max=True, exclude_min=True))
+    beta2 = draw(st.floats(min_value=beta1, max_value=0.999, exclude_max=True, exclude_min=True))
+    return beta1, beta2
+
+
+HYPERPARAM_OVERRIDE_STRATEGIES: Final[dict[str, st.SearchStrategy]] = {
+    "lr": _ZERO_TO_ONE_FLOATS,
+    "weight_decay": _ZERO_TO_ONE_FLOATS,
+    "momentum": _ZERO_TO_ONE_FLOATS,
+    "betas": betas(),
+    "lr_decay": _ZERO_TO_ONE_FLOATS,
+    "eps": _ZERO_TO_ONE_FLOATS,
+    "centered": st.booleans(),
+    "rho": _ZERO_TO_ONE_FLOATS,
+    "momentum_decay": _ZERO_TO_ONE_FLOATS,
+    "etas": st.tuples(
+        st.floats(min_value=0.0, max_value=1.0, exclude_min=True, exclude_max=True),
+        st.floats(min_value=1.0, max_value=2.0, exclude_min=True, exclude_max=True),
+    ),
+    "dampening": _ZERO_TO_ONE_FLOATS,
+    "nesterov": st.booleans(),
+    "initial_accumulator_value": _ZERO_TO_ONE_FLOATS,
+}
+
+
+def optimizer_type_strategy(
+    allowed_optimizer_types: Sequence[type[torch.optim.Optimizer]] | None = None,
+) -> st.SearchStrategy[type[torch.optim.Optimizer]]:
+    """Strategy for generating torch optimizer types.
+
+    Args:
+        allowed_optimizer_types: A sequence of optimizer types to sample from. If None, all available optimizer
+            types are sampled.
+
+    Returns:
+        A strategy for generating torch optimizer types.
+    """
+    if allowed_optimizer_types is None:
+        allowed_optimizer_types = OPTIMIZERS
+    return st.sampled_from(allowed_optimizer_types)
+
+
+@st.composite
+def optimizer_strategy(
+    draw: st.DrawFn,
+    optimizer_type: type[torch.optim.Optimizer] | st.SearchStrategy[type[torch.optim.Optimizer]] | None = None,
+    **kwargs,
+) -> OptimizerConstructorWithOnlyParameters:
+    """Strategy for generating "alternate constructors" for torch optimizers.
+
+    The generated factory takes only a module's parameters; all other hyperparameters are pre-filled with drawn
+    values unless overridden via kwargs.
+
+    Args:
+        draw: The draw function provided by `hypothesis`.
+        optimizer_type: The optimizer type or a strategy for generating optimizer types. If None, an optimizer
+            type is drawn from `optimizer_type_strategy()`.
+        kwargs: Keyword arguments to pass to the optimizer constructor. If a keyword argument is a strategy, it
+            will be drawn from.
+
+    Returns:
+        A factory that constructs the optimizer from an iterator of module parameters.
+    """
+    if optimizer_type is None:
+        optimizer_type = optimizer_type_strategy()
+    if isinstance(optimizer_type, st.SearchStrategy):
+        optimizer_type = draw(optimizer_type)
+
+    sig = inspection_util.infer_signature_annotations(optimizer_type.__init__)
+    for param in sig.parameters.values():
+        if param.name in kwargs and isinstance(kwargs[param.name], st.SearchStrategy):
+            kwargs[param.name] = draw(kwargs[param.name])
+        elif param.annotation is inspect.Parameter.empty:
+            continue
+        elif param.name in HYPERPARAM_OVERRIDE_STRATEGIES:
+            kwargs[param.name] = draw(HYPERPARAM_OVERRIDE_STRATEGIES[param.name])
+        elif param.annotation in (float, int):
+            kwargs[param.name] = draw(_ZERO_TO_ONE_FLOATS)
+        else:
+            kwargs[param.name] = draw(st.from_type(param.annotation))
+    # torch's SGD requires zero dampening when Nesterov momentum is enabled.
+    if "nesterov" in kwargs and kwargs["nesterov"] and "momentum" in kwargs:
+        kwargs["dampening"] = 0
+    kwargs.pop("self", None)  # Remove `self` if it was picked up from the signature
+    kwargs.pop("params", None)  # Remove `params`; the factory supplies them at call time
+
+    hypothesis.note(f"Chosen optimizer type: {optimizer_type}")
+    hypothesis.note(f"Chosen optimizer hyperparameters: {kwargs}")
+
+    def optimizer_factory(params: Iterator[torch.nn.Parameter]) -> torch.optim.Optimizer:
+        return optimizer_type(params, **kwargs)
+
+    return optimizer_factory
diff --git a/tests/unit/test_optim.py b/tests/unit/test_optim.py
new file mode 100644
index 0000000..1d770e1
--- /dev/null
+++ b/tests/unit/test_optim.py
@@ -0,0 +1,50 @@
+"""Tests for the `hypothesis_torch.optim` module."""
+
+import unittest
+
+import hypothesis_torch
+import hypothesis
+from hypothesis import strategies as st
+
+import torch
+
+
+class TestOptimizerTypeStrategy(unittest.TestCase):
+    """Tests for the `optimizer_type_strategy` function."""
+
+    @hypothesis.given(optimizer_type=hypothesis_torch.optimizer_type_strategy())
+    def test_optimizer_type_strategy(self, optimizer_type: type[torch.optim.Optimizer]) -> None:
+        """Test that `optimizer_type_strategy` generates valid optimizer types."""
+        self.assertTrue(issubclass(optimizer_type, torch.optim.Optimizer))
+        self.assertNotEqual(optimizer_type, torch.optim.Optimizer)
+        self.assertNotIn("NewCls", optimizer_type.__name__)
+
+    @hypothesis.given(
+        optimizer_type=hypothesis_torch.optimizer_type_strategy(allowed_optimizer_types=[torch.optim.Adam])
+    )
+    def test_optimizer_type_strategy_allowed_optimizer_types(
+        self, optimizer_type: type[torch.optim.Optimizer]
+    ) -> None:
+        """Test that `optimizer_type_strategy` generates optimizer types when specifying `allowed_optimizer_types`."""
+        self.assertEqual(optimizer_type, torch.optim.Adam)
+
+
+class TestOptimizerStrategy(unittest.TestCase):
+    """Tests for the `optimizer_strategy` function."""
+
+    @hypothesis.settings(deadline=None)
+    @hypothesis.given(
+        optimizer_constructor=hypothesis_torch.optimizer_strategy(),
+        module=hypothesis_torch.linear_network_strategy(
+            input_shape=(1, 1),
+            output_shape=(1, 1),
+            activation_layer=torch.nn.ReLU(),
+            hidden_layer_size=st.integers(min_value=1, max_value=10),
+            num_hidden_layers=st.integers(min_value=1, max_value=10),
+            device=hypothesis_torch.device_strategy(),
+        ),
+    )
+    def test_optimizer_strategy(
+        self, optimizer_constructor: hypothesis_torch.OptimizerConstructorWithOnlyParameters, module: torch.nn.Module
+    ) -> None:
+        """Test that `optimizer_strategy` generates valid optimizers."""
+        optimizer = optimizer_constructor(module.parameters())
+        self.assertIsInstance(optimizer, torch.optim.Optimizer)
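
NOTE (usage sketch, not part of the applied diff): a minimal example of how the
new strategies are intended to compose in a downstream test, assuming only the
exports added to `hypothesis_torch/__init__.py` above. Restricting the type to
`torch.optim.SGD` and overriding `lr` are illustrative choices, not
requirements of the API.

    import hypothesis
    import torch
    from hypothesis import strategies as st

    import hypothesis_torch

    @hypothesis.given(
        optimizer_constructor=hypothesis_torch.optimizer_strategy(
            optimizer_type=torch.optim.SGD,  # Fix the optimizer type; omit to draw one.
            lr=st.just(1e-3),  # Override the pre-filled learning rate hyperparameter.
        )
    )
    def test_sgd_factory(optimizer_constructor) -> None:
        module = torch.nn.Linear(4, 2)
        # The "alternate constructor" needs only the parameters; every other
        # hyperparameter was drawn and pre-filled by the strategy.
        optimizer = optimizer_constructor(module.parameters())
        assert isinstance(optimizer, torch.optim.SGD)
        assert optimizer.defaults["lr"] == 1e-3

Because the drawn hyperparameters are captured in the returned closure, the
same factory can be applied to different modules' parameters, each call
constructing a fresh optimizer with identical settings.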