From 3114162fb77ed48ecb5d980b4e2b41c6184bc3ae Mon Sep 17 00:00:00 2001 From: qthequartermasterman Date: Thu, 9 May 2024 23:21:02 -0500 Subject: [PATCH] fix: :bug: do not try to use the device as a context manager in torch<2. --- hypothesis_torch/module.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/hypothesis_torch/module.py b/hypothesis_torch/module.py index b6da174..735a14f 100644 --- a/hypothesis_torch/module.py +++ b/hypothesis_torch/module.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib from typing import TypeVar, Sequence, Mapping from hypothesis_torch import inspection_util @@ -30,6 +31,22 @@ POSITIVE_INTS = st.integers(min_value=1) +def _context_manager(device: torch.device) -> torch.device | contextlib.nullcontext: + """Return a context manager for the device. + + For torch>=2, this is a no-op. The default device will bet set to the `device` inside the returned context. + For torch<2, however, this returns an empty context manager. No default device will be set. Consequently, manual + casting will be necessary at the end of the context. + + Args: + device: The device to use. + + Returns: + A context manager for the device. + """ + return device if hasattr(device, "__enter__") else contextlib.nullcontext() + + @st.composite def lower_upper_strategy(draw: st.DrawFn) -> tuple[float, float]: """Strategy for generating a pair of floats where the first is less than the second. @@ -209,7 +226,7 @@ def linear_network_strategy( if isinstance(num_hidden_layers, st.SearchStrategy): num_hidden_layers = draw(num_hidden_layers) - with device: + with _context_manager(device): interior_layer_sizes = draw( st.lists( hidden_layer_size,