Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix downscaling large models #145

Merged
merged 14 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ v1.0.2 (unreleased)

Added
-----
- COG files are now written with overviews automatically, for faster visualization PR #144

Changed
-------
Expand All @@ -21,6 +22,7 @@ Fixed
- writing COG files in `SfincsModel.setup_subgrid` (the COG driver settings were wrong) PR #117
- a constant offset in the `datasets_dep` argument to `SfincsModel.setup_subgrid` and `SfincsModel.setup_dep` was ignored PR #119
- mismatch between gis data and the model grid causing issues while reading the model PR #128
- `utils.downscale_floodmap` now also works for large (rotated) grids PR #145

Deprecated
----------
Expand Down
2 changes: 0 additions & 2 deletions hydromt_sfincs/sfincs.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,6 @@ def setup_subgrid(
manning_land: float = 0.04,
manning_sea: float = 0.02,
rgh_lev_land: float = 0.0,
extrapolate_values: bool = False,
write_dep_tif: bool = False,
write_man_tif: bool = False,
):
Expand Down Expand Up @@ -723,7 +722,6 @@ def setup_subgrid(
rgh_lev_land=rgh_lev_land,
write_dep_tif=write_dep_tif,
write_man_tif=write_man_tif,
extrapolate_values=extrapolate_values,
highres_dir=highres_dir,
logger=self.logger,
)
Expand Down
9 changes: 6 additions & 3 deletions hydromt_sfincs/subgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import xarray as xr
from numba import njit
from rasterio.windows import Window
from scipy import ndimage

from . import utils
from . import workflows

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -199,7 +199,6 @@ def build(
manning_sea: float = 0.02,
rgh_lev_land: float = 0.0,
buffer_cells: int = 0,
extrapolate_values: bool = False,
write_dep_tif: bool = False,
write_man_tif: bool = False,
highres_dir: str = None,
Expand Down Expand Up @@ -524,7 +523,11 @@ def build(
del da_mask_block, da_dep, da_man
gc.collect()

# TODO build COG overviews
# Create COG overviews
if write_dep_tif:
utils.build_overviews(fn=fn_dep_tif, resample_method="average")
if write_man_tif:
utils.build_overviews(fn=fn_man_tif, resample_method="average")

def to_xarray(self, dims, coords):
"""Convert subgrid class to xarray dataset."""
Expand Down
224 changes: 197 additions & 27 deletions hydromt_sfincs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import copy
import io
import logging
from configparser import ConfigParser
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Union
Expand All @@ -15,8 +14,9 @@
import hydromt
import numpy as np
import pandas as pd
import pyproj
import rasterio
from rasterio.enums import Resampling
from rasterio.windows import Window
import xarray as xr
from hydromt.io import write_xy
from pyproj.crs.crs import CRS
Expand Down Expand Up @@ -50,6 +50,7 @@
"read_sfincs_his_results",
"downscale_floodmap",
"rotated_grid",
"build_overviews",
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -834,21 +835,25 @@ def read_sfincs_his_results(

def downscale_floodmap(
zsmax: xr.DataArray,
dep: xr.DataArray,
dep: Union[Path, str, xr.DataArray],
hmin: float = 0.05,
gdf_mask: gpd.GeoDataFrame = None,
floodmap_fn: Union[Path, str] = None,
reproj_method: str = "nearest",
**kwargs,
) -> xr.Dataset:
):
"""Create a downscaled floodmap for (model) region.

Parameters
----------
zsmax : xr.DataArray
Maximum water level (m). When multiple timesteps provided, maximum over all timesteps is used.
dep : xr.DataArray
High-resolution DEM (m) of model region.
dep : Path, str, xr.DataArray
High-resolution DEM (m) of model region:
* If a Path or str is provided, the DEM is read from disk and the floodmap
is written to disk (recommended for datasets that do not fit in memory).
* If a xr.DataArray is provided, the floodmap is returned as xr.DataArray
and only written to disk when floodmap_fn is provided.
hmin : float, optional
Minimum water depth (m) to be considered as "flooded", by default 0.05
gdf_mask : gpd.GeoDataFrame, optional
Expand Down Expand Up @@ -876,37 +881,133 @@ def downscale_floodmap(
if timedim:
zsmax = zsmax.max(timedim)

# interpolate zsmax to dep grid
zsmax = zsmax.raster.reproject_like(dep, method=reproj_method)
zsmax = zsmax.raster.mask_nodata() # make sure nodata is nan

# get flood depth
hmax = (zsmax - dep).astype("float32")
hmax.raster.set_nodata(np.nan)

# mask floodmap
hmax = hmax.where(hmax > hmin)
if gdf_mask is not None:
mask = hmax.raster.geometry_mask(gdf_mask, all_touched=True)
hmax = hmax.where(mask)
if isinstance(dep, xr.DataArray):
hmax = _downscale_floodmap_da(
zsmax=zsmax,
dep=dep,
hmin=hmin,
gdf_mask=gdf_mask,
reproj_method=reproj_method,
)

# write floodmap
if floodmap_fn is not None:
if not kwargs: # write COG by default
kwargs = dict(
# write floodmap
if floodmap_fn is not None:
if not kwargs: # write COG by default
kwargs = dict(
driver="GTiff",
tiled=True,
blockxsize=256,
blockysize=256,
compress="deflate",
predictor=2,
profile="COG",
)
hmax.raster.to_raster(floodmap_fn, **kwargs)

# add overviews
build_overviews(fn=floodmap_fn, resample_method="nearest")

hmax.name = "hmax"
hmax.attrs.update({"long_name": "Maximum flood depth", "units": "m"})
return hmax

elif isinstance(dep, (str, Path)):
assert (
floodmap_fn is not None
), "floodmap_fn should be provided when dep is a Path or str."
roeldegoede marked this conversation as resolved.
Show resolved Hide resolved

with rasterio.open(dep) as src:
# Define block size
n1, m1 = src.shape
nrcb = 2000 # nr of cells in a block
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set as argument?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good one, added an nrmax (similar as in setup_subgrid) argument.

nrbn = int(np.ceil(n1 / nrcb)) # nr of blocks in n direction
nrbm = int(np.ceil(m1 / nrcb)) # nr of blocks in m direction

# avoid blocks with width or height of 1
merge_last_col = False
merge_last_row = False
if m1 % nrcb == 1:
nrbm -= 1
merge_last_col = True
if n1 % nrcb == 1:
nrbn -= 1
merge_last_row = True

profile = dict(
driver="GTiff",
width=src.width,
height=src.height,
count=1,
dtype=np.float32,
crs=src.crs,
transform=src.transform,
tiled=True,
blockxsize=256,
blockysize=256,
compress="deflate",
predictor=2,
profile="COG",
nodata=np.nan,
BIGTIFF="YES", # Add the BIGTIFF option here
)
hmax.raster.to_raster(floodmap_fn, **kwargs)

hmax.name = "hmax"
hmax.attrs.update({"long_name": "Maximum flood depth", "units": "m"})
return hmax
with rasterio.open(floodmap_fn, "w", **profile):
pass

## Loop through blocks
for ii in range(nrbm):
bm0 = ii * nrcb # Index of first m in block
bm1 = min(bm0 + nrcb, m1) # last m in block
if merge_last_col and ii == (nrbm - 1):
bm1 += 1

for jj in range(nrbn):
bn0 = jj * nrcb # Index of first n in block
bn1 = min(bn0 + nrcb, n1) # last n in block
if merge_last_row and jj == (nrbn - 1):
bn1 += 1

# Define a window to read a block of data
window = Window(bm0, bn0, bm1 - bm0, bn1 - bn0)

# Read the block of data
block_data = src.read(window=window)

# check for nan-data
if np.all(np.isnan(block_data)):
continue
# Convert row and column indices to pixel coordinates
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to do this now, but since we are working with rasterio (rather than xarray) it would make sense to also directly use the rasterio warp method rather than the raster.reproject method in _downscale_floodmap_da. I think this could provide quite a significant performance improvement. That would affect the whole block of code below this line.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting suggestion but let's postpone that for now. I added a TODO to the code

cols, rows = np.indices((bm1 - bm0, bn1 - bn0))
x_coords, y_coords = src.transform * (cols + bm0, rows + bn0)

# Create xarray DataArray with coordinates
block_dep = xr.DataArray(
block_data.squeeze().transpose(),
dims=("y", "x"),
coords={
"yc": (("y", "x"), y_coords),
"xc": (("y", "x"), x_coords),
},
)
block_dep.raster.set_crs(src.crs)

block_hmax = _downscale_floodmap_da(
zsmax=zsmax,
dep=block_dep,
hmin=hmin,
gdf_mask=gdf_mask,
reproj_method=reproj_method,
)

with rasterio.open(floodmap_fn, "r+") as fm_tif:
fm_tif.write(
np.transpose(block_hmax.values),
window=window,
indexes=1,
)

# add overviews
build_overviews(fn=floodmap_fn, resample_method="nearest")


def rotated_grid(
Expand Down Expand Up @@ -958,3 +1059,72 @@ def _dist(a, b):
nmax = int(np.ceil(axis2 / res))

return x0, y0, mmax, nmax, rot


def build_overviews(fn: Union[str, Path], resample_method: str = "average"):
    """Build overviews for a GeoTIFF file.

    Overviews are reduced-resolution versions of a dataset that can speed up
    rendering when full resolution is not needed. Because the downsampled
    pixels are precomputed, rendering can be significantly faster when zoomed out.

    The file is opened in update mode ("r+") and modified in place: overviews
    with decimation factors 2, 4, 8, 16 and 32 are built and the resampling
    method is recorded as a tag in the "rio_overview" namespace.

    Parameters
    ----------
    fn : str, Path
        Path to the GeoTIFF file. The name must end in ".tif" or ".tiff"
        (case-insensitive).
    resample_method : str
        Name of a ``rasterio.enums.Resampling`` member, by default "average".
        Other option is "nearest".

    Raises
    ------
    AssertionError
        If `fn` does not have a GeoTIFF extension.
    """
    extensions = (".tif", ".tiff")

    # Compare on the string form so that both str and Path inputs are accepted
    # (Path has no ``endswith``), and lower-case it so ".TIF" is accepted too.
    # The error is raised explicitly instead of via ``assert`` so the check is
    # not stripped under ``python -O``; the exception type is kept for
    # backward compatibility.
    if not str(fn).lower().endswith(extensions):
        raise AssertionError(f"File {fn} is not a GeoTIFF file.")

    # open the dataset in update mode so the overviews are stored in the file
    with rasterio.open(fn, "r+") as src:
        # create new overviews with the requested resampling method
        src.build_overviews([2, 4, 8, 16, 32], getattr(Resampling, resample_method))

        # update dataset tags
        src.update_tags(ns="rio_overview", resampling=resample_method)


def _downscale_floodmap_da(
    zsmax: xr.DataArray,
    dep: xr.DataArray,
    hmin: float = 0.05,
    gdf_mask: gpd.GeoDataFrame = None,
    reproj_method: str = "nearest",
) -> xr.DataArray:
    """Compute a flood-depth raster on the grid of a high-resolution DEM.

    The water level is reprojected onto the grid of `dep`, the DEM is
    subtracted from it, and cells that do not exceed `hmin` (or that fall
    outside `gdf_mask`, when given) are set to NaN.

    Parameters
    ----------
    zsmax : xr.DataArray
        Maximum water level (m).
    dep : xr.DataArray
        High-resolution DEM (m) of model region.
    hmin : float, optional
        Minimum water depth (m) to be considered as "flooded", by default 0.05
    gdf_mask : gpd.GeoDataFrame, optional
        Geodataframe with polygons to mask floodmap, example containing the
        landarea, by default None. The area outside the polygons is set to nodata.
    reproj_method : str, optional
        Method passed to ``raster.reproject_like``, by default "nearest".

    Returns
    -------
    xr.DataArray
        Flood depth as float32 with NaN as nodata value.
    """
    # bring the water level onto the DEM grid and make sure nodata is nan
    wl_on_dem = zsmax.raster.reproject_like(dep, method=reproj_method)
    wl_on_dem = wl_on_dem.raster.mask_nodata()

    # water depth = water level minus bed level
    depth = (wl_on_dem - dep).astype("float32")
    depth.raster.set_nodata(np.nan)

    # keep only cells that are deeper than the flooding threshold
    depth = depth.where(depth > hmin)

    # without polygons there is nothing more to mask
    if gdf_mask is None:
        return depth

    inside = depth.raster.geometry_mask(gdf_mask, all_touched=True)
    return depth.where(inside)
Loading