Skip to content

Commit

Permalink
Merge branch 'main' into increase-coverage-for-tensor-strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
qthequartermasterman committed Sep 6, 2024
2 parents 6334067 + 92cd8a0 commit e4bcc8f
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 125 deletions.
6 changes: 5 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ updates:
- package-ecosystem: "pip"
directory: "/" # Location of package manifests
schedule:
interval: "daily"
interval: "weekly"
target-branch: "main"
reviewers:
- "qthequartermasterman"
Expand All @@ -19,10 +19,14 @@ updates:
exclude-patterns:
- "hypothesis"
- "torch*"
- "transformers"
direct-dependencies:
patterns:
- "hypothesis"
- "torch*"
optional-dependencies:
patterns:
- "transformers"
- package-ecosystem: github-actions
directory: /
schedule:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
- name: Install Specific PyTorch version
if: ${{ inputs.torch-version != '' }}
run: |
uv pip install torch==${{ inputs.torch-version }}
uv pip install 'torch==${{ inputs.torch-version }}+cpu' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Install dependencies
run: |
uv pip install ".[dev,huggingface]"
Expand Down
13 changes: 6 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ permissions:
jobs:
build:
strategy:
fail-fast: false
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '']
exclude:
- python-version: '3.12'
torch-version: '2.1.2'
uses: ./.github/workflows/build-reusable.yml
with:
python-version: ${{ matrix.python-version }}
torch-version: ${{ matrix.torch-version }}
lint: ${{ matrix.python-version == '3.9' }}

build-pytorch-1-13:
uses: ./.github/workflows/build-reusable.yml
with:
python-version: '3.9'
lint: false
torch-version: '1.13.1'
334 changes: 229 additions & 105 deletions CHANGELOG.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion hypothesis_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
lacks built-in support for Pytorch tensors and modules, so this library provides strategies for generating them.
"""

__version__ = "0.7.11"
__version__ = "0.7.13"
import importlib.util

from hypothesis_torch.device import (
Expand Down
19 changes: 11 additions & 8 deletions hypothesis_torch/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@

from hypothesis_torch import inspection_util

OptimizerConstructorWithOnlyParameters: TypeAlias = Callable[[Iterable[torch.nn.Parameter]], torch.optim.Optimizer]
# We do an alias here to avoid some type checking issues with torch 2.4.0
Optimizer: TypeAlias = torch.optim.Optimizer # pyright: ignore[reportPrivateImportUsage]

OPTIMIZERS: Final[tuple[type[torch.optim.Optimizer], ...]] = tuple(
OptimizerConstructorWithOnlyParameters: TypeAlias = Callable[[Iterable[torch.nn.Parameter]], Optimizer]

OPTIMIZERS: Final[tuple[type[Optimizer], ...]] = tuple(
optimizer
for optimizer in inspection_util.get_all_subclasses(torch.optim.Optimizer)
if optimizer is not torch.optim.Optimizer and "NewCls" not in optimizer.__name__
for optimizer in inspection_util.get_all_subclasses(Optimizer)
if optimizer is not Optimizer and "NewCls" not in optimizer.__name__
)

_ZERO_TO_ONE_FLOATS: Final[st.SearchStrategy[float]] = st.floats(
Expand Down Expand Up @@ -63,8 +66,8 @@ def betas(draw: st.DrawFn) -> tuple[float, float]:


def optimizer_type_strategy(
allowed_optimizer_types: Sequence[type[torch.optim.Optimizer]] | None = None,
) -> st.SearchStrategy[type[torch.optim.Optimizer]]:
allowed_optimizer_types: Sequence[type[Optimizer]] | None = None,
) -> st.SearchStrategy[type[Optimizer]]:
"""Strategy for generating torch optimizers.
Args:
Expand All @@ -81,7 +84,7 @@ def optimizer_type_strategy(
@st.composite
def optimizer_strategy(
draw: st.DrawFn,
optimizer_type: type[torch.optim.Optimizer] | st.SearchStrategy[type[torch.optim.Optimizer]] | None = None,
optimizer_type: type[Optimizer] | st.SearchStrategy[type[Optimizer]] | None = None,
**kwargs: Any, # noqa: ANN401
) -> OptimizerConstructorWithOnlyParameters:
"""Strategy for generating torch optimizers.
Expand Down Expand Up @@ -121,7 +124,7 @@ def optimizer_strategy(
hypothesis.note(f"Chosen optimizer type: {optimizer_type}")
hypothesis.note(f"Chosen optimizer hyperparameters: {kwargs}")

def optimizer_factory(params: Iterable[torch.nn.Parameter]) -> torch.optim.Optimizer:
def optimizer_factory(params: Iterable[torch.nn.Parameter]) -> Optimizer:
return optimizer_type(params, **kwargs)

return optimizer_factory
3 changes: 2 additions & 1 deletion hypothesis_torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def downcast(x: float) -> float:
hypothesis.assume(memory_format != torch.channels_last or len(tensor.shape) == 4)
# channel_last_3d memory format is only supported for 5D tensors
hypothesis.assume(memory_format != torch.channels_last_3d or len(tensor.shape) == 5)
tensor = tensor.to(memory_format=memory_format)
# Pyright/mypy falsely reports an error here on py3.9 torch 2.1.2.
tensor = tensor.to(memory_format=memory_format) # type: ignore

return tensor

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "Hypothesis strategies for various Pytorch structures, including t
dynamic = ["version"]
dependencies = [
"hypothesis>=6.0.0",
"torch>=1.13.0",
"torch>=2.1.0",
]
requires-python = ">=3.9"
authors=[{name="Andrew P. Sansom", email="[email protected]"}]
Expand Down

0 comments on commit e4bcc8f

Please sign in to comment.