Skip to content

Commit

Permalink
create gdf meta data as named tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
ValentinGebhart committed Jul 10, 2024
1 parent 24ba039 commit d79b670
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 218 deletions.
20 changes: 13 additions & 7 deletions climada/hazard/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import climada.util.constants as u_const
import climada.util.coordinates as u_coord
import climada.util.dates_times as u_dt
from climada.util.plot import GdfMeta


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -487,18 +488,23 @@ def local_return_period(self, threshold_intensities):
end_col = min(start_col + block_size, num_cen)
return_periods[:, start_col:end_col] = self._loc_return_period(
threshold_intensities,
self.intensity[:, start_col:end_col].toarray())
self.intensity[:, start_col:end_col].toarray()
)

# create the output GeoDataFrame
gdf = gpd.GeoDataFrame(geometry = self.centroids.gdf['geometry'], crs = self.centroids.gdf.crs)
col_names = [f'{tresh_inten}' for tresh_inten in threshold_intensities]
gdf.columns.name = (('name', 'Return Period'),
('unit', 'Years'),
('col_name', 'Threshold Intensity'),
('col_unit', self.units))
gdf[col_names] = return_periods.T

return gdf
# create the meta data describing the GeoDataFrame columns
gdf_meta = GdfMeta(
name = 'Return Period',
unit = 'Years',
col_name = 'Threshold Intensity',
col_unit = self.units
)

return gdf, gdf_meta

def get_event_id(self, event_name):
"""Get an event id from its name. Several events might have the same
Expand Down Expand Up @@ -682,7 +688,7 @@ def _loc_return_period(self, threshold_intensities, inten):
sort_pos = np.argsort(inten, axis=0)[::-1, :]
inten_sort = inten[sort_pos, np.arange(inten.shape[1])]
freq_sort = self.frequency[sort_pos]
np.cumsum(freq_sort, axis=0, out=freq_sort)
freq_sort = np.cumsum(freq_sort, axis=0)
return_periods = np.zeros((len(threshold_intensities), inten.shape[1]))

for cen_idx in range(inten.shape[1]):
Expand Down
122 changes: 0 additions & 122 deletions climada/hazard/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,128 +1039,6 @@ def write_hdf5(self, file_name, todense=False):
)
self.centroids.write_hdf5(file_name, mode='a')

def write_raster_local_exceedance_inten(self, return_periods, filename):
    """Write local exceedance intensities to a multi-band GeoTIFF file.

    The exceedance intensities are obtained from
    ``self.local_exceedance_inten`` and written with one band per row of
    its result. Each band is described by the corresponding return period.

    Parameters
    ----------
    return_periods : np.array or list
        Array or list of return periods (in years) for which to calculate
        and store exceedance intensities.
    filename : str
        Path and name of the file to write in tif format.

    Raises
    ------
    ValueError
        If ``self.centroids.get_meta()`` returns nothing, or if the pixel
        width and height of its transform differ by more than 1e-5.
    """
    inten_stats = self.local_exceedance_inten(return_periods=return_periods)
    num_bands = inten_stats.shape[0]

    # Query the raster meta data once and reuse it below.
    # NOTE(review): assumes get_meta() is deterministic between calls - confirm.
    meta = self.centroids.get_meta()
    if not meta:
        raise ValueError("centroids.get_meta() is required but not set.")

    # Build one square polygon per centroid; this stands in for the former
    # ``pixel_geom = self.centroids.calc_pixels_polygons()``.
    transform = meta['transform']
    if abs(abs(transform.a) - abs(transform.e)) > 1.0e-5:
        raise ValueError('Area can not be computed for not squared pixels.')
    pixel_geom = self.centroids.geometry.buffer(transform.a / 2).envelope

    # Work on a copy so the meta data returned by the centroids is not modified.
    profile = meta.copy()
    profile.update(driver='GTiff', dtype='float32', count=num_bands)

    with rasterio.open(filename, 'w', **profile) as dst:
        LOGGER.info('Writing %s', filename)
        for band in range(num_bands):
            raster = rasterio.features.rasterize(
                list(zip(pixel_geom, inten_stats[band].reshape(-1))),
                out_shape=(profile['height'], profile['width']),
                transform=profile['transform'],
                fill=0,
                all_touched=True,
                dtype=profile['dtype'],
            )
            # Raster bands are 1-indexed.
            dst.write(raster, band + 1)

            band_name = f"Exceedance intensity for RP {return_periods[band]} years"
            dst.set_band_description(band + 1, band_name)

def write_raster_local_return_periods(
    self, threshold_intensities, filename, output_resolution=None
):
    """Write local return periods map as GeoTIFF file.

    One band is written per threshold intensity. If the raster described
    by the centroids' meta data has exactly as many pixels as there are
    centroids, the values are written directly; otherwise they are
    rasterized from per-centroid pixel shapes.

    Parameters
    ----------
    threshold_intensities : np.array
        Hazard intensities to consider for calculating return periods.
    filename : str
        File name to write in tif format.
    output_resolution : int, optional
        If not None, the data is rasterized to this resolution.
        Default is None (resolution is estimated from the data).
    """
    # Calculate the local return periods.
    # NOTE(review): the indexing below assumes this returns an array-like with
    # one row per threshold intensity - confirm against the current return type.
    variable = self.local_return_period(threshold_intensities)

    # Obtain the meta information for the raster file.
    meta = self.centroids.get_meta(resolution=output_resolution)
    meta.update(driver='GTiff', dtype=rasterio.float32, count=len(threshold_intensities))
    res = meta["transform"][0]  # resolution from lon coordinates

    if meta['height'] * meta['width'] == self.centroids.size:
        # Centroids already in raster format.
        u_coord.write_raster(filename, variable, meta)
        return

    geometry = self.centroids.get_pixel_shapes(res=res)
    with rasterio.open(filename, 'w', **meta) as dst:
        LOGGER.info('Writing %s', filename)
        # Raster bands are 1-indexed, hence ``start=1``.
        for band, threshold in enumerate(threshold_intensities, start=1):
            raster = rasterio.features.rasterize(
                zip(geometry, variable[band - 1].flatten()),
                out_shape=(meta['height'], meta['width']),
                transform=meta['transform'],
                fill=0,
                all_touched=True,
                dtype=meta['dtype'],
            )
            dst.write(raster.astype(meta['dtype']), band)

            # Set the band description.
            band_name = f"RP of intensity {threshold} {self.units}"
            dst.set_band_description(band, band_name)


def write_netcdf_local_return_periods(self, threshold_intensities, filename):
    """Generate local return period data and save it into a NetCDF file.

    The file gets a single ``centroids`` dimension, the variables
    ``latitude`` and ``longitude`` taken from ``self.centroids.coord``, and
    one variable per threshold intensity holding the return periods.

    Parameters
    ----------
    threshold_intensities : np.array
        Hazard intensities to consider for calculating return periods.
    filename : str
        Path and name of the file to write the NetCDF data.
    """
    # NOTE(review): the indexing below assumes this returns an array-like with
    # one row per threshold intensity - confirm against the current return type.
    return_periods = self.local_return_period(threshold_intensities)
    coords = self.centroids.coord

    with nc.Dataset(filename, 'w', format='NETCDF4') as dataset:
        dataset.createDimension('centroids', coords.shape[0])

        # Column 0 of the coordinates is written as latitude, column 1 as longitude.
        latitudes = dataset.createVariable('latitude', 'f4', ('centroids',))
        longitudes = dataset.createVariable('longitude', 'f4', ('centroids',))
        latitudes[:] = coords[:, 0]
        longitudes[:] = coords[:, 1]
        latitudes.units = 'degrees_north'
        longitudes.units = 'degrees_east'

        for i, intensity in enumerate(threshold_intensities):
            # NOTE(review): int() truncates, so thresholds such as 0.2 and 0.7
            # yield the same variable name; presumably createVariable then
            # fails on the duplicate - confirm and decide on a naming scheme.
            dataset_name = f'return_period_intensity_{int(intensity)}'
            return_period_var = dataset.createVariable(dataset_name, 'f4', ('centroids',))
            return_period_var[:] = return_periods[i, :]
            return_period_var.units = 'years'
            return_period_var.description = (
                f'Local return period for hazard intensity {intensity} {self.units}'
            )

        dataset.description = 'Local return period data for given hazard intensities'

def read_hdf5(self, *args, **kwargs):
"""This function is deprecated, use Hazard.from_hdf5."""
LOGGER.warning("The use of Hazard.read_hdf5 is deprecated."
Expand Down
2 changes: 1 addition & 1 deletion climada/hazard/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def test_local_return_period(self):
])
haz.frequency = np.full(4, 1.)
threshold_intensities = np.array([1., 2., 4.])
return_stats = haz.local_return_period(threshold_intensities)
return_stats, _ = haz.local_return_period(threshold_intensities)
np.testing.assert_allclose(
return_stats[return_stats.columns[1:]].values.T,
np.array([
Expand Down
69 changes: 0 additions & 69 deletions climada/hazard/test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,75 +686,6 @@ class CustomID:
self.assertTrue(np.array_equal(hazard.date, hazard_read.date))
self.assertTrue(np.array_equal(hazard_read.event_id, np.array([]))) # Empty array

# class TestWriteExceedAndRP(unittest.TestCase):
# """Test writing raster and netCDF files from exceedance intensity and return period maps"""

# def test_write_raster_exceed_inten(self):
# """Test write TIFF file from local exceedance intensity"""
# self.temp_dir = TemporaryDirectory()
# self.test_file_path = os.path.join(self.temp_dir.name, 'test_file.tif')

# haz = Hazard.from_hdf5(HAZ_TEST_TC)
# haz.write_raster_local_exceedance_inten([10, 20, 50], filename = self.test_file_path)
# raster = rasterio.open(self.test_file_path)
# dataarray = np.array([ raster.read(i + 1) for i in range(raster.count)])

# np.testing.assert_array_almost_equal(
# np.transpose(np.flip(dataarray, axis = 1), axes= [0, 2, 1]),
# haz.local_exceedance_inten([10, 20, 50]).reshape((3, 10, 10)),
# decimal=4
# )
# self.temp_dir.cleanup()

# def test_write_raster_local_return_periods(self):
# """Test write TIFF file from local return periods intensity"""
# self.temp_dir = TemporaryDirectory()
# self.test_file_path = os.path.join(self.temp_dir.name, 'test_file.tif')

# haz = Hazard.from_hdf5(HAZ_TEST_TC)
# haz.write_raster_local_return_periods([10., 20., 30.], filename = self.test_file_path)
# raster = rasterio.open(self.test_file_path)
# dataarray = np.array([ raster.read(i + 1) for i in range(raster.count)])

# np.testing.assert_array_almost_equal(
# dataarray,
# haz.local_return_period([10., 20., 30.]).reshape((3, 10, 10)),
# decimal=4
# )
# self.temp_dir.cleanup()

# def test_write_raster_local_return_periods_not_raster(self):
# """Test write TIFF file from local return periods intensity"""
# self.temp_dir = TemporaryDirectory()
# self.test_file_path = os.path.join(self.temp_dir.name, 'test_file.tif')

# haz = dummy_hazard()
# haz.write_raster_local_return_periods([.1, 1.], filename = self.test_file_path)
# raster = rasterio.open(self.test_file_path)
# dataarray = np.array([ raster.read(i + 1) for i in range(raster.count)])

# np.testing.assert_array_almost_equal(
# dataarray.max(axis=1),
# haz.local_return_period([.1, 1.]),
# decimal=4
# )
# self.temp_dir.cleanup()

# def test_write_netcdf_local_return_periods(self):
# """Test write netCDF file from local return periods intensity"""
# self.temp_dir = TemporaryDirectory()
# self.test_file_path = os.path.join(self.temp_dir.name, 'test_file.nc')

# haz = Hazard.from_hdf5(HAZ_TEST_TC)
# haz.write_netcdf_local_return_periods([10., 20., 30.], filename = self.test_file_path)
# dataset = xr.load_dataset(self.test_file_path)

# np.testing.assert_array_almost_equal(
# dataset.to_array().data[2:,:],
# haz.local_return_period([10., 20., 30.]),
# decimal=4
# )
# self.temp_dir.cleanup()

# Execute Tests
if __name__ == "__main__":
Expand Down
44 changes: 30 additions & 14 deletions climada/util/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from rasterio.crs import CRS
import requests
from geopandas import GeoDataFrame
from collections import namedtuple

from climada.util.constants import CMAP_EXPOSURES, CMAP_CAT, CMAP_RASTER
from climada.util.files_handler import to_list
Expand Down Expand Up @@ -876,14 +877,20 @@ def multibar_plot(ax, data, colors=None, total_width=0.8, single_width=1,
if legend:
ax.legend(bars, data.keys())

# Meta data of a GeoDataFrame, stored as a named tuple: 'name' and 'unit'
# describe the quantity held in the GeoDataFrame, 'col_name' and 'col_unit'
# describe the quantity encoded in its column labels.
GdfMeta = namedtuple('GdfMeta', 'name unit col_name col_unit')

def subplots_from_gdf(gdf, smooth=True, axis=None, figsize=(9, 13), adapt_fontsize=True, **kwargs):
def subplots_from_gdf(gdf: GeoDataFrame, gdf_meta: GdfMeta = None, smooth=True, axis=None, figsize=(9, 13), adapt_fontsize=True, **kwargs):
"""Plot hazard local return periods for given hazard intensities.
Parameters
----------
gdf: gpd.GeoDataFrame
return periods per threshold intensity
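gdf_meta: climada.util.plot.GdfMeta, optional
Meta data of gdf as a named tuple with attributes 'name' (quantity in gdf), 'unit' (unit thereof),
'col_name' (quantity in column labels), and 'col_unit' (unit thereof). If it is not a GdfMeta
instance, no colorbar label is set and the subplot titles are the plain column labels. Default is None.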
smooth: bool, optional
Smooth plot to plot.RESOLUTION x plot.RESOLUTION. Default is True
axis: matplotlib.axes._subplots.AxesSubplot, optional
Expand All @@ -901,32 +908,41 @@ def subplots_from_gdf(gdf, smooth=True, axis=None, figsize=(9, 13), adapt_fontsi
axis: matplotlib.axes._subplots.AxesSubplot
Matplotlib axis with the plot.
"""
# check if inputs are correct types
if not isinstance(gdf, GeoDataFrame):
raise ValueError("gdf is not a GeoDataFrame")
gdf = gdf[['geometry', *[col for col in gdf.columns if col != 'geometry']]]
try:
meta = {key: val for key, val in gdf.columns.name}
colbar_name = f"{meta['name']} ({meta['unit']})"
title_subplots = [f"{meta['col_name']}: {thres_inten} {meta['col_unit']}"

# read meta data for fig and axis labels
if not isinstance(gdf_meta, GdfMeta):
#warnings.warn("gdf_meta variable is not of type climada.util.plot.GdfMeta. Figure and axis labels will be missing.")
print("gdf_meta variable is not of type climada.util.plot.GdfMeta. Figure and axis labels will be missing.")
colbar_name, title_subplots = None, [f"{col}" for col in gdf.columns[1:]]
else:
colbar_name = f"{gdf_meta.name} ({gdf_meta.unit})"
title_subplots = [f"{gdf_meta.col_name}: {thres_inten} {gdf_meta.col_unit}"
for thres_inten in gdf.columns[1:]]

# change default plot kwargs if plotting return periods
if meta['name'] == 'Return Period':
if 'camp' not in kwargs.keys():
if gdf_meta.name == 'Return Period':
if 'cmap' not in kwargs.keys():
kwargs.update({'cmap': 'viridis_r'})
if 'norm' not in kwargs.keys():
kwargs.update(
{'norm': mpl.colors.LogNorm(vmin=gdf.values[:,1:].min(), vmax=gdf.values[:,1:].max()),
'vmin': None, 'vmax': None}
)
except:
colbar_name, title_subplots = None, [f"{col}" for col in gdf.columns[1:]]


axis = geo_im_from_array(
gdf.values[:,1:].T,
gdf.geometry.get_coordinates().values[:,::-1],
colbar_name, title_subplots,
smooth=smooth, axes=axis,
figsize=figsize, adapt_fontsize=adapt_fontsize, **kwargs)
colbar_name,
title_subplots,
smooth=smooth,
axes=axis,
figsize=figsize,
adapt_fontsize=adapt_fontsize,
**kwargs
)

return axis
12 changes: 7 additions & 5 deletions climada/util/test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ def test_subplots_from_gdf(self):
columns = ('10.0', '20.0')
)
return_periods['geometry'] = (Point(45., 26.), Point(46., 26.), Point(45., 27.), Point(46., 27.))
return_periods.columns.name = (('name', 'Return Period'),
('unit', 'Years'),
('col_name', 'Threshold Intensity'),
('col_unit', 'm/s'))
(axis1, axis2) = u_plot.subplots_from_gdf(return_periods)
gdf_meta = u_plot.GdfMeta(
name = 'Return Period',
unit = 'Years',
col_name = 'Threshold Intensity',
col_unit = 'm/s'
)
(axis1, axis2) = u_plot.subplots_from_gdf(return_periods, gdf_meta)
self.assertEqual('Threshold Intensity: 10.0 m/s', axis1.get_title())
self.assertEqual('Threshold Intensity: 20.0 m/s', axis2.get_title())
plt.close()
Expand Down

0 comments on commit d79b670

Please sign in to comment.