Merge pull request #18 from qthequartermasterman/optimizers
feat: ✨ add strategies for torch optimizers
qthequartermasterman authored May 9, 2024
2 parents 7ea76aa + 4eb2f4d commit fbefa06
Showing 5 changed files with 223 additions and 43 deletions.
1 change: 1 addition & 0 deletions hypothesis_torch/__init__.py
@@ -28,6 +28,7 @@
from hypothesis_torch.layout import layout_strategy
from hypothesis_torch.memory_format import memory_format_strategy
from hypothesis_torch.module import linear_network_strategy, same_shape_activation_strategy
from hypothesis_torch.optim import optimizer_strategy, optimizer_type_strategy, OptimizerConstructorWithOnlyParameters
from hypothesis_torch.register_random_torch_state import TORCH_RANDOM_WRAPPER
from hypothesis_torch.tensor import tensor_strategy

23 changes: 23 additions & 0 deletions hypothesis_torch/inspection_util.py
@@ -4,6 +4,9 @@

import inspect
from typing import Callable, TypeVar
from hypothesis import strategies as st

import torch

T = TypeVar("T")

@@ -59,3 +62,23 @@ def get_all_subclasses(cls: type[T]) -> list[type[T]]:
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses


@st.composite
def signature_to_strategy(draw: st.DrawFn, constructor: type[T], *args, **kwargs) -> T:
    """Strategy for generating instances of a class by drawing values for its constructor.
    Args:
        draw: The draw function provided by `hypothesis`.
        constructor: The class to generate an instance of.
        args: Positional arguments to pass to the constructor. If an argument is a strategy, it will be drawn from.
        kwargs: Keyword arguments to pass to the constructor. If a keyword argument is a strategy, it will be drawn
            from.
    Returns:
        An instance of the class.
    """
    args_drawn = [draw(strategy) for strategy in args]
    kwargs_drawn = {k: draw(strategy) for k, strategy in kwargs.items()}
    return constructor(*args_drawn, **kwargs_drawn)
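
For orientation, a minimal usage sketch of the relocated helper (the `nn.Linear` example and its bounds are illustrative, not part of this commit). Note that every argument forwarded to `signature_to_strategy` must itself be a strategy, since the helper draws from each one unconditionally:

import torch.nn as nn
from hypothesis import given, strategies as st
from hypothesis_torch import inspection_util

# Build a strategy for nn.Linear by supplying a strategy per constructor argument.
linear_strategy = inspection_util.signature_to_strategy(
    nn.Linear,
    in_features=st.integers(min_value=1, max_value=8),
    out_features=st.integers(min_value=1, max_value=8),
    bias=st.booleans(),
)

@given(layer=linear_strategy)
def test_linear_layer(layer: nn.Linear) -> None:
    assert isinstance(layer, nn.Linear)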
71 changes: 28 additions & 43 deletions hypothesis_torch/module.py
@@ -4,6 +4,7 @@

from typing import TypeVar, Sequence, Mapping

from hypothesis_torch import inspection_util
import torch
from hypothesis import strategies as st
from torch import nn
@@ -60,26 +61,6 @@ def rrelu_strategy(draw: st.DrawFn) -> nn.RReLU:
    return nn.RReLU(lower, upper, inplace)


@st.composite
def signature_to_strategy(draw: st.DrawFn, constructor: type[T], *args, **kwargs) -> T:
    """Strategy for generating instances of a class by drawing values for its constructor.
    Args:
        draw: The draw function provided by `hypothesis`.
        constructor: The class to generate an instance of.
        args: Positional arguments to pass to the constructor. If an argument is a strategy, it will be drawn from.
        kwargs: Keyword arguments to pass to the constructor. If a keyword argument is a strategy, it will be drawn
            from.
    Returns:
        An instance of the class.
    """
    args_drawn = [draw(strategy) for strategy in args]
    kwargs_drawn = {k: draw(strategy) for k, strategy in kwargs.items()}
    return constructor(*args_drawn, **kwargs_drawn)


@st.composite
def hard_tanh_strategy(draw: st.DrawFn) -> nn.Hardtanh:
"""Strategy for generating instances of `nn.Hardtanh` by drawing values for its constructor.
@@ -98,38 +79,42 @@ def hard_tanh_strategy(draw: st.DrawFn) -> nn.Hardtanh:


activation_strategies: dict[type[nn.Module], st.SearchStrategy[nn.Module]] = {
    nn.Identity: signature_to_strategy(nn.Identity),
    nn.ELU: signature_to_strategy(nn.ELU, alpha=SENSIBLE_FLOATS, inplace=st.booleans()),
    nn.Hardshrink: signature_to_strategy(nn.Hardshrink, lambd=SENSIBLE_FLOATS),
    nn.Hardsigmoid: signature_to_strategy(nn.Hardsigmoid, inplace=st.booleans()),
    nn.Identity: inspection_util.signature_to_strategy(nn.Identity),
    nn.ELU: inspection_util.signature_to_strategy(nn.ELU, alpha=SENSIBLE_FLOATS, inplace=st.booleans()),
    nn.Hardshrink: inspection_util.signature_to_strategy(nn.Hardshrink, lambd=SENSIBLE_FLOATS),
    nn.Hardsigmoid: inspection_util.signature_to_strategy(nn.Hardsigmoid, inplace=st.booleans()),
    nn.Hardtanh: hard_tanh_strategy(),
    nn.Hardswish: signature_to_strategy(nn.Hardswish, inplace=st.booleans()),
    nn.LeakyReLU: signature_to_strategy(nn.LeakyReLU, negative_slope=SENSIBLE_FLOATS, inplace=st.booleans()),
    nn.LogSigmoid: signature_to_strategy(nn.LogSigmoid),
    nn.Hardswish: inspection_util.signature_to_strategy(nn.Hardswish, inplace=st.booleans()),
    nn.LeakyReLU: inspection_util.signature_to_strategy(
        nn.LeakyReLU, negative_slope=SENSIBLE_FLOATS, inplace=st.booleans()
    ),
    nn.LogSigmoid: inspection_util.signature_to_strategy(nn.LogSigmoid),
    # TODO: nn.MultiheadAttention, although listed in the `Non-linear activations` section, does not produce
    # outputs with the same shape as its inputs
    # TODO: nn.PReLU(num_parameters=1, init=0.25, device=None, dtype=None)
    # TODO: PReLU might depend on the input shape
    # TODO: num_parameters (int) – the number of `a` to learn. Although it takes an int as input, only two values
    # are legitimate: 1, or the number of channels at the input. Default: 1
    nn.PReLU: signature_to_strategy(nn.PReLU, num_parameters=st.just(1), init=SENSIBLE_FLOATS),
    nn.ReLU: signature_to_strategy(nn.ReLU, inplace=st.booleans()),
    nn.ReLU6: signature_to_strategy(nn.ReLU6, inplace=st.booleans()),
    nn.PReLU: inspection_util.signature_to_strategy(nn.PReLU, num_parameters=st.just(1), init=SENSIBLE_FLOATS),
    nn.ReLU: inspection_util.signature_to_strategy(nn.ReLU, inplace=st.booleans()),
    nn.ReLU6: inspection_util.signature_to_strategy(nn.ReLU6, inplace=st.booleans()),
    nn.RReLU: rrelu_strategy(),
    nn.SELU: signature_to_strategy(nn.SELU, inplace=st.booleans()),
    nn.CELU: signature_to_strategy(
    nn.SELU: inspection_util.signature_to_strategy(nn.SELU, inplace=st.booleans()),
    nn.CELU: inspection_util.signature_to_strategy(
        nn.CELU, alpha=SENSIBLE_FLOATS.filter(lambda x: abs(x) > 1e-5), inplace=st.booleans()
    ),
    nn.GELU: signature_to_strategy(nn.GELU, approximate=st.sampled_from(["none", "tanh"])),
    nn.Sigmoid: signature_to_strategy(nn.Sigmoid),
    nn.SiLU: signature_to_strategy(nn.SiLU, inplace=st.booleans()),
    nn.Mish: signature_to_strategy(nn.Mish, inplace=st.booleans()),
    nn.Softplus: signature_to_strategy(nn.Softplus, beta=SENSIBLE_FLOATS, threshold=SENSIBLE_POSITIVE_FLOATS),
    nn.Softshrink: signature_to_strategy(nn.Softshrink, lambd=SENSIBLE_POSITIVE_FLOATS),
    nn.Softsign: signature_to_strategy(nn.Softsign),
    nn.Tanh: signature_to_strategy(nn.Tanh),
    nn.Tanhshrink: signature_to_strategy(nn.Tanhshrink),
    nn.Threshold: signature_to_strategy(
    nn.GELU: inspection_util.signature_to_strategy(nn.GELU, approximate=st.sampled_from(["none", "tanh"])),
    nn.Sigmoid: inspection_util.signature_to_strategy(nn.Sigmoid),
    nn.SiLU: inspection_util.signature_to_strategy(nn.SiLU, inplace=st.booleans()),
    nn.Mish: inspection_util.signature_to_strategy(nn.Mish, inplace=st.booleans()),
    nn.Softplus: inspection_util.signature_to_strategy(
        nn.Softplus, beta=SENSIBLE_FLOATS, threshold=SENSIBLE_POSITIVE_FLOATS
    ),
    nn.Softshrink: inspection_util.signature_to_strategy(nn.Softshrink, lambd=SENSIBLE_POSITIVE_FLOATS),
    nn.Softsign: inspection_util.signature_to_strategy(nn.Softsign),
    nn.Tanh: inspection_util.signature_to_strategy(nn.Tanh),
    nn.Tanhshrink: inspection_util.signature_to_strategy(nn.Tanhshrink),
    nn.Threshold: inspection_util.signature_to_strategy(
        nn.Threshold, threshold=SENSIBLE_FLOATS, value=SENSIBLE_FLOATS, inplace=st.booleans()
    ),
    # TODO: nn.GLU depends on the input shape
@@ -206,7 +191,7 @@ def linear_network_strategy(
    if isinstance(device, st.SearchStrategy):
        device = draw(device)

    if isinstance(activation_layer, nn.Module):
    if not isinstance(activation_layer, st.SearchStrategy):
        activation_layer_strategy = st.just(activation_layer)
    else:
        activation_layer_strategy = activation_layer
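The one-line fix above loosens how `linear_network_strategy` detects a fixed activation layer: any non-strategy value is now wrapped in `st.just(...)`, rather than only `nn.Module` instances. A hedged sketch of the two calling conventions this supports (mirroring the keyword arguments used by the unit test below; the shapes and bounds here are illustrative):

import torch
import hypothesis_torch
from hypothesis import strategies as st

# A fixed activation module is wrapped internally with st.just(...).
fixed_activation_networks = hypothesis_torch.linear_network_strategy(
    input_shape=(1, 4),
    output_shape=(1, 2),
    activation_layer=torch.nn.ReLU(),
    hidden_layer_size=st.integers(min_value=1, max_value=8),
    num_hidden_layers=st.integers(min_value=1, max_value=3),
    device=torch.device("cpu"),
)

# An activation strategy is drawn from directly.
drawn_activation_networks = hypothesis_torch.linear_network_strategy(
    input_shape=(1, 4),
    output_shape=(1, 2),
    activation_layer=hypothesis_torch.same_shape_activation_strategy(),
    hidden_layer_size=st.integers(min_value=1, max_value=8),
    num_hidden_layers=st.integers(min_value=1, max_value=3),
    device=torch.device("cpu"),
)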
121 changes: 121 additions & 0 deletions hypothesis_torch/optim.py
@@ -0,0 +1,121 @@
"""Strategies for generating torch optimizers."""

from __future__ import annotations

from typing import Sequence, Final, Callable, Iterator
from typing_extensions import TypeAlias

import hypothesis
import torch.optim
import inspect

from hypothesis import strategies as st
from hypothesis_torch import inspection_util

OptimizerConstructorWithOnlyParameters: TypeAlias = Callable[[Iterator[torch.nn.Parameter]], torch.optim.Optimizer]

OPTIMIZERS: Final[tuple[type[torch.optim.Optimizer], ...]] = tuple(
optimizer
for optimizer in inspection_util.get_all_subclasses(torch.optim.Optimizer)
if optimizer is not torch.optim.Optimizer and "NewCls" not in optimizer.__name__
)

_ZERO_TO_ONE_FLOATS: Final[st.SearchStrategy[float]] = st.floats(
min_value=0.0, max_value=1.0, exclude_max=True, exclude_min=True
)


@st.composite
def betas(draw: st.DrawFn) -> tuple[float, float]:
"""Strategy for generating beta1 and beta2 values for optimizers.
Args:
draw: The draw function provided by `hypothesis`.
Returns:
A tuple of beta1 and beta2 values.
"""
beta1 = draw(st.floats(min_value=0.0, max_value=0.95, exclude_max=True, exclude_min=True))
beta2 = draw(st.floats(min_value=beta1, max_value=0.999, exclude_max=True, exclude_min=True))
return beta1, beta2


HYPERPARAM_OVERRIDE_STRATEGIES: Final[dict[str, st.SearchStrategy]] = {
"lr": _ZERO_TO_ONE_FLOATS,
"weight_decay": _ZERO_TO_ONE_FLOATS,
"momentum": _ZERO_TO_ONE_FLOATS,
"betas": betas(),
"lr_decay": _ZERO_TO_ONE_FLOATS,
"eps": _ZERO_TO_ONE_FLOATS,
"centered": st.booleans(),
"rho": _ZERO_TO_ONE_FLOATS,
"momentum_decay": _ZERO_TO_ONE_FLOATS,
"etas": st.tuples(
st.floats(min_value=0.0, max_value=1.0, exclude_min=True, exclude_max=True),
st.floats(min_value=1.0, max_value=2.0, exclude_min=True, exclude_max=True),
),
"dampening": _ZERO_TO_ONE_FLOATS,
"nesterov": st.booleans(),
"initial_accumulator_value": _ZERO_TO_ONE_FLOATS,
}


def optimizer_type_strategy(
    allowed_optimizer_types: Sequence[type[torch.optim.Optimizer]] | None = None,
) -> st.SearchStrategy[type[torch.optim.Optimizer]]:
    """Strategy for generating torch optimizer types.
    Args:
        allowed_optimizer_types: A sequence of optimizer types to sample from. If None, all available optimizer
            types are sampled.
    Returns:
        A strategy for generating torch optimizer types.
    """
    if allowed_optimizer_types is None:
        allowed_optimizer_types = OPTIMIZERS
    return st.sampled_from(allowed_optimizer_types)


@st.composite
def optimizer_strategy(
    draw: st.DrawFn,
    optimizer_type: type[torch.optim.Optimizer] | st.SearchStrategy[type[torch.optim.Optimizer]] | None = None,
    **kwargs,
) -> OptimizerConstructorWithOnlyParameters:
    """Strategy for generating torch optimizers.
    Args:
        draw: The draw function provided by `hypothesis`.
        optimizer_type: The optimizer type or a strategy for generating optimizer types. If None, a type is drawn
            from `optimizer_type_strategy()`.
        kwargs: Keyword arguments to pass to the optimizer constructor. If a keyword argument is a strategy, it will
            be drawn from.
    Returns:
        A constructor that takes a module's parameters and returns an optimizer.
    """
    if optimizer_type is None:
        optimizer_type = optimizer_type_strategy()
    if isinstance(optimizer_type, st.SearchStrategy):
        optimizer_type = draw(optimizer_type)

    sig = inspection_util.infer_signature_annotations(optimizer_type.__init__)
    for param in sig.parameters.values():
        if param.name in kwargs and isinstance(kwargs[param.name], st.SearchStrategy):
            kwargs[param.name] = draw(kwargs[param.name])
        elif param.annotation is inspect.Parameter.empty:
            continue
        elif param.name in HYPERPARAM_OVERRIDE_STRATEGIES:
            kwargs[param.name] = draw(HYPERPARAM_OVERRIDE_STRATEGIES[param.name])
        elif param.annotation in (float, int):
            kwargs[param.name] = draw(_ZERO_TO_ONE_FLOATS)
        else:
            kwargs[param.name] = draw(st.from_type(param.annotation))
    if "nesterov" in kwargs and kwargs["nesterov"] and "momentum" in kwargs:
        kwargs["dampening"] = 0  # torch requires zero dampening when nesterov momentum is enabled
    kwargs.pop("self", None)  # Remove self if a type was inferred
    kwargs.pop("params", None)  # Remove params if a type was inferred

    hypothesis.note(f"Chosen optimizer type: {optimizer_type}")
    hypothesis.note(f"Chosen optimizer hyperparameters: {kwargs}")

    def optimizer_factory(params: Sequence[torch.nn.Parameter]) -> torch.optim.Optimizer:
        return optimizer_type(params, **kwargs)

    return optimizer_factory
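
A short usage sketch of the new strategy (the SGD choice and `lr` bounds are illustrative): `optimizer_strategy` yields a constructor that still needs a parameter iterable, so it can be bound to whatever module the test builds. A caller-supplied strategy for a hyperparameter such as `lr` is drawn in preference to the defaults in `HYPERPARAM_OVERRIDE_STRATEGIES`:

import torch
import hypothesis
from hypothesis import strategies as st
import hypothesis_torch

@hypothesis.given(
    optimizer_constructor=hypothesis_torch.optimizer_strategy(
        optimizer_type=torch.optim.SGD,
        lr=st.floats(min_value=1e-4, max_value=1e-1),  # caller override, drawn before the built-in lr strategy
    )
)
def test_sgd_factory(optimizer_constructor: hypothesis_torch.OptimizerConstructorWithOnlyParameters) -> None:
    module = torch.nn.Linear(2, 2)  # any module with parameters will do
    optimizer = optimizer_constructor(module.parameters())
    assert isinstance(optimizer, torch.optim.SGD)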
50 changes: 50 additions & 0 deletions tests/unit/test_optim.py
@@ -0,0 +1,50 @@
"""Tests for the `hypothesis_torch.optim` module."""

import unittest

import hypothesis_torch
import hypothesis
from hypothesis import strategies as st

import torch


class TestOptimizerTypeStrategy(unittest.TestCase):
"""Tests for the `optimizer_type_strategy` function."""

@hypothesis.given(optimizer_type=hypothesis_torch.optimizer_type_strategy())
def test_optimizer_type_strategy(self, optimizer_type: type[torch.optim.Optimizer]) -> None:
"""Test that `optimizer_type_strategy` generates valid optimizer types."""
self.assertTrue(issubclass(optimizer_type, torch.optim.Optimizer))
self.assertNotEqual(optimizer_type, torch.optim.Optimizer)
self.assertNotIn("NewCls", optimizer_type.__name__)

@hypothesis.given(
optimizer_type=hypothesis_torch.optimizer_type_strategy(allowed_optimizer_types=[torch.optim.Adam])
)
def test_optimizer_type_strategy_allowed_optimizer_types(self, optimizer_type: type[torch.optim.Optimizer]) -> None:
"""Test that `optimizer_type_strategy` generates optimizer types when specifying `allowed_optimizer_types`."""
self.assertEqual(optimizer_type, torch.optim.Adam)


class TestOptimizerStrategy(unittest.TestCase):
"""Tests for the `optimizer_strategy` function."""

@hypothesis.settings(deadline=None)
@hypothesis.given(
optimizer_constructor=hypothesis_torch.optimizer_strategy(),
module=hypothesis_torch.linear_network_strategy(
input_shape=(1, 1),
output_shape=(1, 1),
activation_layer=torch.nn.ReLU(),
hidden_layer_size=st.integers(min_value=1, max_value=10),
num_hidden_layers=st.integers(min_value=1, max_value=10),
device=hypothesis_torch.device_strategy(),
),
)
def test_optimizer_strategy(
self, optimizer_constructor: hypothesis_torch.OptimizerConstructorWithOnlyParameters, module: torch.nn.Module
) -> None:
"""Test that `optimizer_strategy` generates valid optimizers."""
optimizer = optimizer_constructor(module.parameters())
self.assertIsInstance(optimizer, torch.optim.Optimizer)

4 comments on commit fbefa06

@github-actions

Coverage

Coverage Report
File                  Stmts   Miss   Cover   Missing
hypothesis_torch
  device.py              19      1     95%   43
  huggingface.py        105     74     30%   102–106, 127–213, 239–257
  inspection_util.py     33     15     55%   28–33, 38–47, 82–84
  module.py              77     46     40%   44–46, 59–61, 75–78, 146, 150, 179–225, 236–241
  optim.py               47     27     43%   38–40, 93–121
  tensor.py              79     60     24%   88–192
  utils.py               17      7     59%   47–53
TOTAL                   453    230     49%

Tests Skipped Failures Errors Time
1 0 💤 0 ❌ 1 🔥 2.363s ⏱️

@github-actions

Coverage

Coverage Report
File                  Stmts   Miss   Cover   Missing
hypothesis_torch
  device.py              19      1     95%   43
  huggingface.py        105     27     74%   104–106, 130, 136, 153, 160, 164, 168, 172–173, 176, 187–190, 192, 194–196, 199, 202, 206, 211, 240, 243, 251
  inspection_util.py     33      1     97%   30
  module.py              77     15     81%   146, 150, 180, 182, 186, 188, 197, 201–203, 236–241
  optim.py               47      1     98%   101
  tensor.py              79      6     92%   147, 155, 157, 173, 179, 181
TOTAL                   453     51     89%

Tests Skipped Failures Errors Time
1256 1205 💤 0 ❌ 0 🔥 1m 16s ⏱️

@github-actions

Coverage

Coverage Report
File                  Stmts   Miss   Cover   Missing
hypothesis_torch
  device.py              19      1     95%   43
  huggingface.py        105     27     74%   104–106, 130, 136, 153, 160, 164, 168, 172–173, 176, 187–190, 192, 194–196, 199, 202, 206, 211, 240, 243, 251
  inspection_util.py     33      1     97%   30
  module.py              77     15     81%   146, 150, 180, 182, 186, 188, 197, 201–203, 236–241
  optim.py               47      1     98%   101
  tensor.py              77      6     92%   147, 155, 157, 173, 179, 181
  utils.py               17      1     94%   33
TOTAL                   443     52     88%

Tests Skipped Failures Errors Time
1256 1205 💤 0 ❌ 0 🔥 1m 20s ⏱️

@github-actions

Coverage

Coverage Report
File                  Stmts   Miss   Cover   Missing
hypothesis_torch
  device.py              19      1     95%   43
  huggingface.py        105     27     74%   104–106, 130, 136, 153, 160, 164, 168, 172–173, 176, 187–190, 192, 194–196, 199, 202, 206, 211, 240, 243, 251
  inspection_util.py     33      1     97%   30
  module.py              77     21     73%   44–46, 59–61, 146, 150, 180, 182, 186, 188, 197, 201–203, 236–241
  optim.py               47      1     98%   101
  tensor.py              79      6     92%   147, 155, 157, 173, 179, 181
TOTAL                   453     57     87%

Tests Skipped Failures Errors Time
1256 1205 💤 0 ❌ 0 🔥 1m 25s ⏱️
