Skip to content

Commit

Permalink
convert to float32 to keep pydata#1840 in sync
Browse files Browse the repository at this point in the history
  • Loading branch information
kmuehlbauer committed Apr 4, 2023
1 parent 19ef234 commit e877aa7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
18 changes: 10 additions & 8 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def _choose_float_dtype(
) -> type[np.floating[Any]]:
# check scale/offset first to derive dtype
# see https://github.com/pydata/xarray/issues/5597#issuecomment-879561954
scale_factor = mapping.get("scale_factor", False)
add_offset = mapping.get("add_offset", False)
scale_factor = mapping.get("scale_factor")
add_offset = mapping.get("add_offset")
if scale_factor or add_offset:
# get the maximum itemsize from scale_factor/add_offset to determine
# the needed floating point type
Expand All @@ -320,7 +320,7 @@ def _choose_float_dtype(
# but a large integer offset could lead to loss of precision.
# Sensitivity analysis can be tricky, so we just use a float64
# if there's any offset at all - better unoptimised than wrong!
if maxsize == 4 and np.issubdtype(add_offset_type, np.floating):
if maxsize == 4 or not np.issubdtype(add_offset_type, np.floating):
return np.float32
else:
return np.float64
Expand Down Expand Up @@ -350,12 +350,14 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if scale_factor or add_offset:
dtype = _choose_float_dtype(data.dtype, attrs)
data = data.astype(dtype=dtype, copy=True)
if add_offset:
data -= add_offset
if scale_factor:
data /= scale_factor
if add_offset:
data -= add_offset
if scale_factor:
data /= scale_factor

return Variable(dims, data, attrs, encoding, fastpath=True)
return Variable(dims, data, attrs, encoding, fastpath=True)
else:
return variable

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
dims, data, attrs, encoding = unpack_for_decoding(variable)
Expand Down
7 changes: 4 additions & 3 deletions xarray/tests/test_coding.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,11 @@ def test_coder_roundtrip() -> None:
assert_identical(original, roundtripped)


@pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split())
def test_scaling_converts_to_float32(dtype) -> None:
@pytest.mark.parametrize("unpacked_dtype", [np.float32, np.float64, np.int32])
@pytest.mark.parametrize("packed_dtype", "u1 u2 i1 i2 f2 f4".split())
def test_scaling_converts_to_float32(packed_dtype, unpacked_dtype) -> None:
original = xr.Variable(
("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=10)
("x",), np.arange(10, dtype=packed_dtype), encoding=dict(scale_factor=unpacked_dtype(10))
)
coder = variables.CFScaleOffsetCoder()
encoded = coder.encode(original)
Expand Down

0 comments on commit e877aa7

Please sign in to comment.