diff --git a/hypothesis_torch/dtype.py b/hypothesis_torch/dtype.py index d97b2e2..9b0884c 100644 --- a/hypothesis_torch/dtype.py +++ b/hypothesis_torch/dtype.py @@ -50,7 +50,7 @@ torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, - torch.bfloat16: np.float32, # Numpy does not have a bf16, but it has the same dynamic range as f32 + torch.bfloat16: np.float32, # Numpy does not have a bf16, but it is a strict subset of fp32 torch.complex64: complex, torch.complex128: complex, torch.bool: np.bool_, @@ -60,7 +60,7 @@ float_width_map: Final[Mapping[torch.dtype, Literal[16, 32, 64]]] = { torch.float16: 16, - torch.bfloat16: 32, # Numpy does not have a bf16, but it has the same dynamic range as f32 + torch.bfloat16: 32, # Numpy does not have a bf16, but it is a strict subset of f32 torch.float32: 32, torch.float64: 64, } diff --git a/hypothesis_torch/tensor.py b/hypothesis_torch/tensor.py index 85547f3..3b6d1aa 100644 --- a/hypothesis_torch/tensor.py +++ b/hypothesis_torch/tensor.py @@ -95,14 +95,14 @@ def tensor_strategy( if isinstance(device, st.SearchStrategy): device = draw(device) # MPS devices do not support tensors with dtype torch.float64 and bfloat16 - hypothesis.assume(not (device is not None and device.type == "mps" and dtype in (torch.float64, torch.bfloat16))) + hypothesis.assume(device is None or device.type != "mps" or dtype not in (torch.float64, torch.bfloat16)) if layout is None: layout = st.from_type(torch.layout) if isinstance(layout, st.SearchStrategy): layout = draw(layout) # MPS devices do not support sparse tensors - hypothesis.assume(not (device is not None and device.type == "mps" and layout == torch.sparse_coo)) + hypothesis.assume(device is None or device.type != "mps" or layout != torch.sparse_coo) # If the dtype is an integer, we need to make sure that the elements are integers within the dtype's range if dtype in dtype_module.INT_DTYPES and isinstance(elements, st.SearchStrategy): @@ -112,6 +112,14 @@ def tensor_strategy( # If the dtype is a float, then we need to make sure that only elements that can be represented exactly are # generated if dtype in dtype_module.FLOAT_DTYPES and elements is not None: + if dtype == torch.bfloat16: + # Since we do not directly support bfloat16 in numpy, we will generate float32. + # This still means that we will occasionally generate values that exceed the max/min of bfloat16. + # All other values (within the range) will be simply truncated below when casting the numpy array to + # a bfloat16 tensor. + bfloat16_info = torch.finfo(torch.bfloat16) + elements = elements.filter(lambda x: bfloat16_info.min <= x <= bfloat16_info.max) + width = dtype_module.float_width_map[dtype] if width < 64: @@ -129,9 +137,9 @@ def downcast(x: float) -> float: The downcasted float. """ try: - return float_of(x, width) + return hypothesis.internal.floats.float_of(x, width) except OverflowError: # pragma: no cover - reject() + hypothesis.reject() elements = elements.map(downcast)