Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom dataset names in 'generic_image' reader and fix nodata handling #1560

Merged
merged 20 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions satpy/readers/generic_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

Returns a dataset without calibration. Includes coordinates if
available in the file (eg. geotiff).
If nodata values are present (and rasterio is able to read them), it
will be preserved as attribute ``_FillValue`` in the returned dataset.
In case that nodata values should be used to mask pixels (that have
equal values) with np.nan, it has to be enabled in the reader yaml
file (key ``nodata_handling`` per dataset with value ``"nan_mask"``).
"""

import logging
Expand All @@ -38,6 +43,9 @@
3: ['R', 'G', 'B'],
4: ['R', 'G', 'B', 'A']}

NODATA_HANDLING_FILLVALUE = 'fill_value'
NODATA_HANDLING_NANMASK = 'nan_mask'

logger = logging.getLogger(__name__)


Expand All @@ -56,6 +64,7 @@ def __init__(self, filename, filename_info, filetype_info):
self.finfo['filename'] = self.filename
self.file_content = {}
self.area = None
self.dataset_name = None
self.read()

def read(self):
Expand All @@ -75,14 +84,9 @@ def read(self):
# Rename bands to [R, G, B, A], or a subset of those
data['bands'] = BANDS[data.bands.size]

# Mask data if alpha channel is present
try:
data = mask_image_data(data)
except ValueError as err:
logger.warning(err)

data.attrs = attrs
self.file_content['image'] = data
self.dataset_name = 'image'
djhoese marked this conversation as resolved.
Show resolved Hide resolved
self.file_content[self.dataset_name] = data

def get_area_def(self, dsid):
"""Get area definition of the image."""
Expand All @@ -102,12 +106,29 @@ def end_time(self):

def get_dataset(self, key, info):
"""Get a dataset from the file."""
logger.debug("Reading %s.", key)
return self.file_content[key['name']]
ds_name = self.dataset_name if self.dataset_name else key['name']
djhoese marked this conversation as resolved.
Show resolved Hide resolved
logger.debug("Reading '%s.'", ds_name)
data = self.file_content[ds_name]

# Mask data if necessary
try:
data = _mask_image_data(data, info)
except ValueError as err:
logger.warning(err)

def mask_image_data(data):
"""Mask image data if alpha channel is present."""
data.attrs.update(key.to_dict())
data.attrs.update(info)
return data


def _mask_image_data(data, info):
"""Mask image data if necessary.

Masking is done if alpha channel is present or
dataset 'nodata_handling' is set to 'nan_mask'.
In the latter case even integer data is converted
to float32 and masked with np.nan.
"""
if data.bands.size in (2, 4):
if not np.issubdtype(data.dtype, np.integer):
raise ValueError("Only integer datatypes can be used as a mask.")
Expand All @@ -117,4 +138,26 @@ def mask_image_data(data):
for i in range(data.shape[0])])
data.data = masked_data
data = data.sel(bands=BANDS[data.bands.size - 1])
elif hasattr(data, 'nodatavals') and data.nodatavals:
data = _handle_nodatavals(data, info.get('nodata_handling', NODATA_HANDLING_FILLVALUE))
return data


def _handle_nodatavals(data, nodata_handling):
"""Mask data with np.nan or only set 'attr_FillValue'."""
if nodata_handling == NODATA_HANDLING_NANMASK:
# data converted to float and masked with np.nan
data = data.astype(np.float32)
masked_data = da.stack([da.where(data.data[i, :, :] == nodataval, np.nan, data.data[i, :, :])
for i, nodataval in enumerate(data.nodatavals)])
data.data = masked_data
data.attrs['_FillValue'] = np.nan
elif nodata_handling == NODATA_HANDLING_FILLVALUE:
# keep data as it is but set _FillValue attribute to provided
# nodatavalue (first one as it has to be the same for all bands at least
# in GeoTiff, see GDAL gtiff driver documentation)
fill_value = data.nodatavals[0]
if np.issubdtype(data.dtype, np.integer):
fill_value = int(fill_value)
data.attrs['_FillValue'] = fill_value
return data
113 changes: 103 additions & 10 deletions satpy/tests/reader_tests/test_generic_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def setUp(self):
a__[:10, :10] = 0
a__ = da.from_array(a__, chunks=(50, 50))

r_nan__ = np.random.uniform(0., 1., size=(self.y_size, self.x_size))
r_nan__[:10, :10] = np.nan
r_nan__ = da.from_array(r_nan__, chunks=(50, 50))

ds_l = xr.DataArray(da.stack([r__]), dims=('bands', 'y', 'x'),
attrs={'name': 'test_l',
'start_time': self.date})
Expand All @@ -78,6 +82,12 @@ def setUp(self):
'start_time': self.date})
ds_rgba['bands'] = ['R', 'G', 'B', 'A']

ds_l_nan = xr.DataArray(da.stack([r_nan__]),
dims=('bands', 'y', 'x'),
attrs={'name': 'test_l_nan',
'start_time': self.date})
ds_l_nan['bands'] = ['L']

# Temp dir for the saved images
self.base_dir = tempfile.mkdtemp()

Expand All @@ -91,12 +101,18 @@ def setUp(self):
scn['rgb'].attrs['area'] = self.area_def
scn['rgba'] = ds_rgba
scn['rgba'].attrs['area'] = self.area_def
scn['l_nan'] = ds_l_nan
scn['l_nan'].attrs['area'] = self.area_def

# Save the images. Two images in PNG and two in GeoTIFF
scn.save_dataset('l', os.path.join(self.base_dir, 'test_l.png'), writer='simple_image')
scn.save_dataset('la', os.path.join(self.base_dir, '20180101_0000_test_la.png'), writer='simple_image')
scn.save_dataset('rgb', os.path.join(self.base_dir, '20180101_0000_test_rgb.tif'), writer='geotiff')
scn.save_dataset('rgba', os.path.join(self.base_dir, 'test_rgba.tif'), writer='geotiff')
scn.save_dataset('l_nan', os.path.join(self.base_dir, 'test_l_nan_fillvalue.tif'),
writer='geotiff', fill_value=0)
scn.save_dataset('l_nan', os.path.join(self.base_dir, 'test_l_nan_nofillvalue.tif'),
writer='geotiff')

self.scn = scn

Expand Down Expand Up @@ -133,7 +149,7 @@ def test_png_scene(self):
self.assertEqual(np.sum(np.isnan(data)), 100)

def test_geotiff_scene(self):
"""Test reading PNG images via satpy.Scene()."""
"""Test reading TIFF images via satpy.Scene()."""
from satpy import Scene

fname = os.path.join(self.base_dir, '20180101_0000_test_rgb.tif')
Expand All @@ -154,10 +170,25 @@ def test_geotiff_scene(self):
self.assertEqual(scn.attrs['end_time'], None)
self.assertEqual(scn['image'].area, self.area_def)

def test_geotiff_scene_nan(self):
"""Test reading TIFF images originally containing NaN values via satpy.Scene()."""
from satpy import Scene

fname = os.path.join(self.base_dir, 'test_l_nan_fillvalue.tif')
scn = Scene(reader='generic_image', filenames=[fname])
scn.load(['image'])
self.assertEqual(scn['image'].shape, (1, self.y_size, self.x_size))
self.assertEqual(np.sum(scn['image'].data[0][:10, :10].compute()), 0)

fname = os.path.join(self.base_dir, 'test_l_nan_nofillvalue.tif')
scn = Scene(reader='generic_image', filenames=[fname])
scn.load(['image'])
self.assertEqual(scn['image'].shape, (1, self.y_size, self.x_size))
self.assertTrue(np.all(np.isnan(scn['image'].data[0][:10, :10].compute())))

Comment on lines +177 to +188
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind separating these out into their own test functions. I think we've generally switched to not combining a ton of test cases into one function whenever possible. It makes it easier to debug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

def test_GenericImageFileHandler(self):
"""Test direct use of the reader."""
from satpy.readers.generic_image import GenericImageFileHandler
from satpy.readers.generic_image import mask_image_data

fname = os.path.join(self.base_dir, 'test_rgba.tif')
fname_info = {'start_time': self.date}
Expand All @@ -174,16 +205,78 @@ def test_GenericImageFileHandler(self):
self.assertEqual(reader.start_time, self.date)
self.assertEqual(reader.end_time, self.date)

dataset = reader.get_dataset(foo, None)
dataset = reader.get_dataset(foo, {})
self.assertTrue(isinstance(dataset, xr.DataArray))
self.assertTrue('crs' in dataset.attrs)
self.assertTrue('transform' in dataset.attrs)
self.assertIn('crs', dataset.attrs)
self.assertIn('transform', dataset.attrs)
self.assertTrue(np.all(np.isnan(dataset.data[:, :10, :10].compute())))

# Test masking of floats
def test_GenericImageFileHandler_masking_only_integer(self):
"""Test direct use of the reader."""
from satpy.readers.generic_image import GenericImageFileHandler

class FakeGenericImageFileHandler(GenericImageFileHandler):

def __init__(self, filename, filename_info, filetype_info, file_content, **kwargs):
"""Get fake file content from 'get_test_content'."""
super(GenericImageFileHandler, self).__init__(filename, filename_info, filetype_info)
self.file_content = file_content
self.dataset_name = None
self.file_content.update(kwargs)

data = self.scn['rgba']
self.assertRaises(ValueError, mask_image_data, data / 255.)

# do nothing if not integer
float_data = data / 255.
reader = FakeGenericImageFileHandler("dummy", {}, {}, {"image": float_data})
self.assertIs(reader.get_dataset(make_dataid(name='image'), {}), float_data)

# masking if integer
data = data.astype(np.uint32)
self.assertTrue(data.bands.size == 4)
data = mask_image_data(data)
self.assertTrue(data.bands.size == 3)
self.assertEqual(data.bands.size, 4)
reader = FakeGenericImageFileHandler("dummy", {}, {}, {"image": data})
ret_data = reader.get_dataset(make_dataid(name='image'), {})
self.assertEqual(ret_data.bands.size, 3)

def test_GenericImageFileHandler_datasetid(self):
"""Test direct use of the reader."""
from satpy.readers.generic_image import GenericImageFileHandler

fname = os.path.join(self.base_dir, 'test_rgba.tif')
fname_info = {'start_time': self.date}
ftype_info = {}
reader = GenericImageFileHandler(fname, fname_info, ftype_info)

foo = make_dataid(name='image-custom')
self.assertTrue(reader.file_content, 'file_content should be set')
dataset = reader.get_dataset(foo, {})
self.assertTrue(isinstance(dataset, xr.DataArray), 'dataset should be a xr.DataArray')

def test_GenericImageFileHandler_nodata(self):
"""Test nodata handling with direct use of the reader."""
from satpy.readers.generic_image import GenericImageFileHandler

fname = os.path.join(self.base_dir, 'test_l_nan_fillvalue.tif')
fname_info = {'start_time': self.date}
ftype_info = {}
reader = GenericImageFileHandler(fname, fname_info, ftype_info)

foo = make_dataid(name='image-custom')
self.assertTrue(reader.file_content, 'file_content should be set')
info = {'nodata_handling': 'nan_mask'}
dataset = reader.get_dataset(foo, info)
self.assertTrue(isinstance(dataset, xr.DataArray), 'dataset should be a xr.DataArray')
self.assertTrue(np.all(np.isnan(dataset.data[0][:10, :10].compute())), 'values should be np.nan')
self.assertTrue(np.isnan(dataset.attrs['_FillValue']), '_FillValue should be np.nan')

info = {'nodata_handling': 'fill_value'}
dataset = reader.get_dataset(foo, info)
self.assertTrue(isinstance(dataset, xr.DataArray), 'dataset should be a xr.DataArray')
self.assertEqual(np.sum(dataset.data[0][:10, :10].compute()), 0)
self.assertEqual(dataset.attrs['_FillValue'], 0)

# default same as 'nodata_handling': 'fill_value'
dataset = reader.get_dataset(foo, {})
self.assertTrue(isinstance(dataset, xr.DataArray), 'dataset should be a xr.DataArray')
self.assertEqual(np.sum(dataset.data[0][:10, :10].compute()), 0)
self.assertEqual(dataset.attrs['_FillValue'], 0)