fix: 🐛 do not try to use the device as a context manager in torch<2.
qthequartermasterman committed May 10, 2024
1 parent ef0e47c commit 3114162
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion hypothesis_torch/module.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import contextlib
 from typing import TypeVar, Sequence, Mapping
 
 from hypothesis_torch import inspection_util
@@ -30,6 +31,22 @@
 POSITIVE_INTS = st.integers(min_value=1)
 
 
+def _context_manager(device: torch.device) -> torch.device | contextlib.nullcontext:
+    """Return a context manager for the device.
+
+    For torch>=2, this is a pass-through: the default device will be set to `device` inside the returned context.
+
+    For torch<2, however, this returns an empty context manager. No default device will be set. Consequently,
+    manual casting will be necessary at the end of the context.
+
+    Args:
+        device: The device to use.
+
+    Returns:
+        A context manager for the device.
+    """
+    return device if hasattr(device, "__enter__") else contextlib.nullcontext()
+
+
 @st.composite
 def lower_upper_strategy(draw: st.DrawFn) -> tuple[float, float]:
     """Strategy for generating a pair of floats where the first is less than the second.
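For context: in PyTorch 2.0, torch.device instances gained context-manager support, so entering one sets the default device for everything created inside the block; under torch<2 a device has no __enter__, and `with device:` raises an error. Below is a minimal sketch of the dispatch the helper performs, assuming only that torch and torch.nn are available (`net` is a hypothetical stand-in module, not part of the commit):

    import contextlib

    import torch
    import torch.nn as nn

    device = torch.device("cpu")

    # torch>=2: a torch.device is itself a context manager that sets the
    # default device, so modules created inside land on `device`.
    # torch<2: torch.device has no __enter__, so `with device:` fails;
    # fall back to a do-nothing context instead.
    cm = device if hasattr(device, "__enter__") else contextlib.nullcontext()

    with cm:
        net = nn.Linear(4, 2)  # placed on `device` only under torch>=2

    # Under torch<2 the context was a no-op, so move the module manually.
    if not hasattr(device, "__enter__"):
        net = net.to(device)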
@@ -209,7 +226,7 @@ def linear_network_strategy(
     if isinstance(num_hidden_layers, st.SearchStrategy):
         num_hidden_layers = draw(num_hidden_layers)
 
-    with device:
+    with _context_manager(device):
         interior_layer_sizes = draw(
             st.lists(
                 hidden_layer_size,
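Design note: wrapping the device in `_context_manager` keeps this call site identical across torch versions. Only under torch>=2 does the block actually set the default device; under torch<2 it is a no-op, so the strategy presumably still has to move the sampled modules onto `device` afterwards (the manual casting the docstring warns about), which would happen outside the lines shown in this hunk.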
