Skip to content

Commit

Permalink
Merge pull request #71 from csiro-coasts/timedelta-fill-value
Browse files Browse the repository at this point in the history
Timedelta fill value
  • Loading branch information
mx-moth authored Apr 24, 2023
2 parents 78c67a7 + ba6f6b6 commit 5fe4e6b
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 50 deletions.
4 changes: 3 additions & 1 deletion docs/releases/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
Next release (in development)
=============================

* ...
* Fixed an issue with ``_FillValue`` / ``missing_value``
and variables with non-float types such as ``timedelta64``
(:pr:`71`)
39 changes: 20 additions & 19 deletions src/emsarray/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import xarray as xr
from xarray.core.dtypes import maybe_promote

from emsarray import utils
from emsarray.types import Pathish
Expand Down Expand Up @@ -71,7 +72,6 @@ def mask_grid_dataset(
# file system, at the added expense of having to recombine the dataset
# afterwards.
for key, data_array in dataset.data_vars.items():
logger.debug("DataArray %s", key)
masked_data_array = mask_grid_data_array(mask, data_array)
variable_path = work_path / f"{key}.nc"
mfdataset_names.append(variable_path)
Expand Down Expand Up @@ -130,19 +130,28 @@ def mask_grid_data_array(mask: xr.Dataset, data_array: xr.DataArray) -> xr.DataA
try:
fill_value = find_fill_value(data_array)
except ValueError:
logger.debug(
"Data array %r has no valid fill value, leaving as is",
data_array.name)
return data_array

# Loop through each possible mask
for mask_name, mask_data_array in mask.data_vars.items():
# If every dimension of this mask exists in the data array, apply it
if dimensions >= set(mask_data_array.dims):
logger.debug(
"Masking data array %r with mask %r",
data_array.name, mask_name)
new_data_array = cast(xr.DataArray, data_array.where(mask_data_array, other=fill_value))
new_data_array.attrs = data_array.attrs
new_data_array.encoding = data_array.encoding
return new_data_array

# Fallback, no appropriate mask was found, so don't apply any.
# This generally happens for data arrays such as time, record, x_grid, etc.
logger.debug(
"Data array %r had no relevant mask, leaving as is",
data_array.name)
return data_array


Expand Down Expand Up @@ -182,24 +191,16 @@ def find_fill_value(data_array: xr.DataArray) -> Any:
# constructed a dataset using one...
return np.ma.masked

if '_FillValue' in data_array.encoding:
# The dataset was opened with mask_and_scale=True and a mask has been
# applied. Masked values are now represented as np.nan, not _FillValue.
return np.nan

if '_FillValue' in data_array.attrs:
# The dataset was opened with mask_and_scale=False and a mask has not
# been applied. Masked values should be represented using _FillValue.
return data_array.attrs['_FillValue']

if issubclass(data_array.dtype.type, np.floating):
# NaN is a useful fallback for a _FillValue, but only if the dtype
# is some sort of float. We won't actually _set_ a _FillValue
# attribute though, as that can play havoc when trying to save
# existing datasets. xarray gets real grumpy when you have
# a _FillValue and a missing_value, and some existing datasets play
# fast and loose with mixing the two.
return np.nan
attrs = ['_FillValue', 'missing_value']
for attr in attrs:
if attr in data_array.attrs:
# The dataset was opened with mask_and_scale=False and a mask has not
# been applied. Masked values should be represented using _FillValue/missing_value.
return data_array.attrs[attr]

promoted_dtype, fill_value = maybe_promote(data_array.dtype)
if promoted_dtype == data_array.dtype:
return fill_value

raise ValueError("No appropriate fill value found")

Expand Down
5 changes: 4 additions & 1 deletion src/emsarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from packaging.version import Version
from xarray.coding import times
from xarray.core.common import contains_cftime_datetimes
from xarray.core.dtypes import maybe_promote

from emsarray.types import Pathish

Expand Down Expand Up @@ -233,8 +234,10 @@ def disable_default_fill_value(dataset_or_array: Union[xr.Dataset, xr.DataArray]
The :class:`xarray.Dataset` or :class:`xarray.DataArray` to update
"""
for variable in _get_variables(dataset_or_array):
current_dtype = variable.dtype
promoted_dtype, fill_value = maybe_promote(current_dtype)
if (
issubclass(variable.dtype.type, np.floating)
current_dtype == promoted_dtype
and "_FillValue" not in variable.encoding
and "_FillValue" not in variable.attrs
):
Expand Down
Binary file modified tests/datasets/masking/find_fill_value/float_with_fill_value.nc
Binary file not shown.
Binary file not shown.
Binary file not shown.
50 changes: 44 additions & 6 deletions tests/datasets/masking/find_fill_value/make_datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
#!/usr/bin/env python3

"""
Make some datasets for testing fill values.
Because of how xarray preprocesses variables to apply masks,
it is easier to construct these datasets using the plain netCDF4 library,
save the datasets to disk, and then load them using xarray.
This guarantees that the behaviour in the tests will replicate real-world use.
Running this script will overwrite any datasets already constructed in this directory.
This operation should result in byte-for-byte identical datasets each time it is run.
However each netCDF4 dataset will encode the versions of the
netCDF4, hdf5 and other relevant libraries used to construct the dataset.
If the versions have changed, the script will create new files that git thinks have changed.
"""

import pathlib

import netCDF4
Expand All @@ -8,8 +22,10 @@
here = pathlib.Path(__file__).parent


def make_float_with_fill_value() -> None:
ds = netCDF4.Dataset(here / "float_with_fill_value.nc", "w", "NETCDF4")
def make_float_with_fill_value(
output_path: pathlib.Path = here / "float_with_fill_value.nc"
) -> None:
ds = netCDF4.Dataset(output_path, "w", "NETCDF4")
ds.createDimension("x", 2)
ds.createDimension("y", 2)

Expand All @@ -20,8 +36,10 @@ def make_float_with_fill_value() -> None:
ds.close()


def make_float_with_fill_value_and_offset() -> None:
ds = netCDF4.Dataset(here / "float_with_fill_value_and_offset.nc", "w", "NETCDF4")
def make_float_with_fill_value_and_offset(
output_path: pathlib.Path = here / "float_with_fill_value_and_offset.nc",
) -> None:
ds = netCDF4.Dataset(output_path, "w", "NETCDF4")
ds.createDimension("x", 2)
ds.createDimension("y", 2)

Expand All @@ -34,8 +52,27 @@ def make_float_with_fill_value_and_offset() -> None:
ds.close()


def make_int_with_fill_value_and_offset() -> None:
ds = netCDF4.Dataset(here / "int_with_fill_value_and_offset.nc", "w", "NETCDF4")
def make_timedelta_with_missing_value(
    output_path: pathlib.Path = here / "timedelta_with_missing_value.nc",
) -> None:
    """
    Write a small netCDF4 dataset with a ``missing_value`` but no ``_FillValue``.

    The dataset has two dimensions, ``x`` and ``y``, each of size 2,
    and a single 32 bit float variable named ``var`` with dimensions ``(y, x)``.
    The variable has a ``units`` attribute of ``"days"``
    and a ``missing_value`` attribute of ``1e35``.
    It holds the values 0, 1, 2 in row-major order,
    with the final cell ``[1, 1]`` set to the missing value.

    Parameters
    ----------
    output_path : pathlib.Path
        Where to write the dataset. Any existing file is overwritten.
        Defaults to ``timedelta_with_missing_value.nc`` next to this script.
    """
    missing_value = np.float32(1.e+35)

    # Use the dataset as a context manager so the file handle is released
    # even if one of the writes below raises.
    with netCDF4.Dataset(output_path, "w", "NETCDF4") as ds:
        ds.createDimension("x", 2)
        ds.createDimension("y", 2)

        # fill_value=False: do not pre-fill the variable with a default fill
        # value, so `missing_value` is the only marker for missing data.
        var = ds.createVariable("var", "f4", ["y", "x"], fill_value=False)
        var.missing_value = missing_value
        var.units = "days"
        var[:] = np.arange(4).reshape((2, 2))
        var[1, 1] = missing_value


def make_int_with_fill_value_and_offset(
output_path: pathlib.Path = here / "int_with_fill_value_and_offset.nc",
) -> None:
ds = netCDF4.Dataset(output_path, "w", "NETCDF4")
ds.createDimension("x", 2)
ds.createDimension("y", 2)

Expand All @@ -51,4 +88,5 @@ def make_int_with_fill_value_and_offset() -> None:
if __name__ == '__main__':
    # Build each dataset listed here, in order, using its default output path.
    for make_dataset in (
        make_float_with_fill_value,
        make_float_with_fill_value_and_offset,
        make_timedelta_with_missing_value,
        make_int_with_fill_value_and_offset,
    ):
        make_dataset()
Binary file not shown.
10 changes: 5 additions & 5 deletions tests/masking/test_mask_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ def test_mask_dataset(tmp_path: pathlib.Path):
data=np.random.normal(0, 0.2, (records, j_size, i_size)),
dims=["record", "j_centre", "i_centre"],
attrs={
"units": "metre",
"long_name": "Surface elevation",
"standard_name": "sea_surface_height_above_geoid",
}
"units": "metre",
"long_name": "Surface elevation",
"standard_name": "sea_surface_height_above_geoid",
}
)
temp = xr.DataArray(
data=np.random.normal(12, 0.5, (records, k_size, j_size, i_size)),
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_mask_dataset(tmp_path: pathlib.Path):
assert nc_flag2.shape == (k_size, 4, 3)
flag2_mask = np.stack([np.array([
[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]
])]*k_size).astype(bool)
])] * k_size).astype(bool)
expected: np.ndarray = np.ma.masked_array(
flag2.values[:, 1:5, 1:4].copy(),
mask=flag2_mask,
Expand Down
21 changes: 21 additions & 0 deletions tests/masking/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from numpy.testing import assert_equal

from emsarray import masking
from emsarray.utils import to_netcdf_with_fixes
from tests.utils import mask_from_strings


Expand Down Expand Up @@ -102,6 +103,26 @@ def test_find_fill_value_masked_and_scaled_int(datasets):
assert_dtype_equal(masking.find_fill_value(data_array), np.int8(-1))


def test_find_fill_value_timedelta_with_missing_value(
    datasets: pathlib.Path,
    tmp_path: pathlib.Path,
) -> None:
    """
    A timedelta variable with a ``missing_value`` should use NaT as its fill value.

    Checks the raw values stored on disk first,
    then that xarray opens the variable as ``timedelta64[ns]``,
    that :func:`masking.find_fill_value` returns NaT for it,
    and finally that the dataset can be written out again
    with :func:`to_netcdf_with_fixes`.
    """
    sentinel = np.float32(1.e35)
    nc_path = datasets / 'masking/find_fill_value/timedelta_with_missing_value.nc'

    # On disk the variable is plain float32, with the sentinel in the last cell
    expected_raw = np.array([[0, 1], [2, sentinel]], dtype=np.float32)
    assert_raw_values(nc_path, 'var', expected_raw)

    with xr.open_dataset(nc_path) as dataset:
        data_array = dataset['var']
        assert data_array.dtype == np.dtype('timedelta64[ns]')
        assert np.isnat(masking.find_fill_value(data_array))

        # Saving the opened dataset again should not raise
        to_netcdf_with_fixes(dataset, tmp_path / 'dataset.nc')


def test_calculate_mask_bounds():
mask = xr.Dataset(
data_vars={
Expand Down
57 changes: 39 additions & 18 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,45 +57,66 @@ def test_fix_time_units_for_ems(tmp_path: pathlib.Path):


def test_disable_default_fill_value(tmp_path: pathlib.Path):
foo = xarray.DataArray(
int_var = xarray.DataArray(
data=np.arange(35, dtype=int).reshape(5, 7),
dims=['j', 'i'],
attrs={"Hello": "World"},
)
bar = xarray.DataArray(
data=np.arange(35, dtype=np.float64).reshape(5, 7),
dims=['j', 'i'],
)
baz = xarray.DataArray(
data=np.arange(35, dtype=np.float64).reshape(5, 7),
dims=['j', 'i'],
)
baz.data = np.where(np.tri(5, 7, dtype=bool), baz.data, np.nan)
baz.encoding["_FillValue"] = np.nan

dataset = xarray.Dataset(data_vars={"foo": foo, "bar": bar, "baz": baz})
float_var = xarray.DataArray(
data=np.arange(35, dtype=np.float64).reshape(5, 7),
dims=['j', 'i'])

f_data = np.where(
np.tri(5, 7, dtype=bool),
np.arange(35, dtype=np.float64).reshape(5, 7),
np.nan)
float_with_fill_value_var = xarray.DataArray(data=f_data, dims=['j', 'i'])
float_with_fill_value_var.encoding["_FillValue"] = np.nan

td_data = np.where(
np.tri(5, 7, dtype=bool),
np.arange(35).reshape(5, 7) * np.timedelta64(1, 'D'),
np.timedelta64('nat'))
timedelta_with_missing_value_var = xarray.DataArray(
data=td_data, dims=['j', 'i'])
timedelta_with_missing_value_var.encoding['missing_value'] = np.float64('1e35')
timedelta_with_missing_value_var.encoding['units'] = 'days'

dataset = xarray.Dataset(data_vars={
"int_var": int_var,
"float_var": float_var,
"float_with_fill_value_var": float_with_fill_value_var,
"timedelta_with_missing_value_var": timedelta_with_missing_value_var,
})

# Save to a netCDF4 and then prove that it is bad
dataset.to_netcdf(tmp_path / "bad.nc")
with netCDF4.Dataset(tmp_path / "bad.nc", "r") as nc_dataset:
# This one shouldn't be here because it is an integer datatype. xarray
# does the right thing already in this case.
assert '_FillValue' not in nc_dataset.variables["foo"].ncattrs()
assert '_FillValue' not in nc_dataset.variables["int_var"].ncattrs()
# This one shouldn't be here as we didn't set it, and the array is full!
# This is the problem we are trying to solve
assert np.isnan(nc_dataset.variables["bar"].getncattr("_FillValue"))
assert np.isnan(nc_dataset.variables["float_var"].getncattr("_FillValue"))
# This one is quite alright, we did explicitly set it after all
assert np.isnan(nc_dataset.variables["baz"].getncattr("_FillValue"))
assert np.isnan(nc_dataset.variables["float_with_fill_value_var"].getncattr("_FillValue"))
# This one is incorrect, a `missing_value` attribute has already been set
assert np.isnan(nc_dataset.variables["timedelta_with_missing_value_var"].getncattr("_FillValue"))

utils.disable_default_fill_value(dataset)
dataset.to_netcdf(tmp_path / "good.nc")
with netCDF4.Dataset(tmp_path / "good.nc", "r") as nc_dataset:
# This one should still be unset
assert '_FillValue' not in nc_dataset.variables["foo"].ncattrs()
assert '_FillValue' not in nc_dataset.variables["int_var"].ncattrs()
# This one should now be unset
assert '_FillValue' not in nc_dataset.variables["bar"].ncattrs()
assert '_FillValue' not in nc_dataset.variables["float_var"].ncattrs()
# Make sure this didn't get clobbered
assert np.isnan(nc_dataset.variables["baz"].getncattr("_FillValue"))
assert np.isnan(nc_dataset.variables["float_with_fill_value_var"].getncattr("_FillValue"))
# This one should now be unset
nc_timedelta = nc_dataset.variables["timedelta_with_missing_value_var"]
assert '_FillValue' not in nc_timedelta.ncattrs()
assert nc_timedelta.getncattr('missing_value') == np.float64('1e35')


def test_dataset_like():
Expand Down

0 comments on commit 5fe4e6b

Please sign in to comment.