diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cc9ab8f1..4fec4313 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,12 +7,13 @@ on: pull_request: jobs: - black: + lint: + name: Linting Suite runs-on: ubuntu-latest strategy: matrix: tox-env: - - black + - lint steps: - name: Cancel previous runs uses: styfle/cancel-workflow-action@0.11.0 @@ -28,8 +29,8 @@ jobs: run: tox -e ${{ matrix.tox-env }} pypi: - name: Pip with Python${{ matrix.python-version }} - needs: black + name: Python${{ matrix.python-version }} (PyPI + Tox) + needs: lint runs-on: ubuntu-latest strategy: matrix: @@ -61,17 +62,19 @@ jobs: COVERALLS_SERVICE_NAME: github conda: - name: Anaconda Build with Python${{ matrix.python-version }} (upstream=${{ matrix.upstream }}) - needs: black + name: Python${{ matrix.python-version }} (Anaconda, upstream=${{ matrix.upstream }}) + needs: lint runs-on: ubuntu-latest strategy: fail-fast: false matrix: include: - python-version: "3.9" - upstream: true + upstream: false - python-version: "3.10" upstream: false + - python-version: "3.11" + upstream: false defaults: run: shell: bash -l {0} diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml new file mode 100644 index 00000000..620f89c8 --- /dev/null +++ b/.github/workflows/upstream.yml @@ -0,0 +1,81 @@ +name: Test Upstream Dependencies + +on: + push: + branches: + - master + paths-ignore: + - HISTORY.rst + - README.rst + - pyproject.toml + - setup.cfg + - clisops/__init__.py + schedule: + - cron: "0 0 * * *" # Daily “At 00:00” UTC + workflow_dispatch: # allows you to trigger the workflow run manually + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + upstream-dev: + name: test-upstream-dev (Python${{ matrix.python-version }}) + runs-on: ubuntu-latest + if: | + (github.event_name == 'schedule') || + (github.event_name == 'workflow_dispatch') || + (github.event_name == 'push') + strategy: + fail-fast: false + matrix: + python-version: + - "3.10" + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. 
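For orientation, the `--report-log` option passed to pytest further below makes `pytest-reportlog` write one JSON object per line, which the `xarray-contrib/issue-from-pytest-log` action then turns into an issue. A minimal sketch of consuming such a log directly — assuming only the documented record fields `$report_type`, `when`, `nodeid` and `outcome`:

import json

def failed_tests(path: str) -> list[str]:
    """Collect node ids of failed test calls from a pytest-reportlog JSONL file."""
    failed = []
    with open(path) as fh:
        for line in fh:
            rec = json.loads(line)
            # Only TestReport entries from the "call" phase carry test outcomes.
            if rec.get("$report_type") == "TestReport" and rec.get("when") == "call":
                if rec.get("outcome") == "failed":
                    failed.append(rec["nodeid"])
    return failed

print(failed_tests("output-3.10-log.jsonl"))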
+ - name: Setup Conda (Micromamba) with Python${{ matrix.python-version }} + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: environment.yml + create-args: >- + conda + python=${{ matrix.python-version }} + pytest-reportlog + - name: Conda and Mamba versions + run: | + conda --version + echo "micromamba: $(micromamba --version)" + - name: Install CLISOPS + run: | + python -m pip install --no-user --editable ".[dev]" + - name: Install upstream versions + run: | + python -m pip install -r requirements_upstream.txt + - name: Check versions + run: | + conda list + python -m pip check || true + - name: Run Tests + if: success() + id: status + run: | + python -m pytest --durations=10 --cov=clisops --cov-report=term-missing --report-log output-${{ matrix.python-version }}-log.jsonl + - name: Generate and publish the report + if: | + failure() + && steps.status.outcome == 'failure' + && github.event_name == 'schedule' + && github.repository_owner == 'roocs' + uses: xarray-contrib/issue-from-pytest-log@v1 + with: + log-path: output-${{ matrix.python-version }}-log.jsonl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a86746e..a61584c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,51 +1,55 @@ default_language_version: - python: python3 + python: python3 repos: -- repo: https://github.com/asottile/pyupgrade + - repo: https://github.com/asottile/pyupgrade rev: v3.15.0 hooks: - - id: pyupgrade + - id: pyupgrade args: [ '--py38-plus' ] -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - - id: trailing-whitespace + - id: trailing-whitespace exclude: setup.cfg - - id: end-of-file-fixer + - id: end-of-file-fixer exclude: setup.cfg - - id: check-yaml - - id: debug-statements - - id: mixed-line-ending + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: mixed-line-ending - repo: https://github.com/psf/black-pre-commit-mirror rev: 23.11.0 hooks: - id: black - args: ["--target-version", "py38"] -- repo: https://github.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 6.1.0 hooks: - - id: flake8 - args: ['--config=setup.cfg'] -- repo: https://github.com/PyCQA/isort + - id: flake8 + args: [ '--config=setup.cfg' ] + - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: - - id: isort - args: ['--profile', 'black'] + - id: isort #- repo: https://github.com/pycqa/pydocstyle # rev: 6.1.1 # hooks: # - id: pydocstyle # args: ["--convention=numpy"] -- repo: https://github.com/kynan/nbstripout + - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: - - id: nbstripout + - id: nbstripout files: ".ipynb" -- repo: meta + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.27.1 hooks: - - id: check-hooks-apply - - id: check-useless-excludes + - id: check-github-workflows + - id: check-readthedocs + - repo: meta + hooks: + - id: check-hooks-apply + - id: check-useless-excludes ci: autofix_commit_msg: | @@ -55,5 +59,5 @@ ci: autoupdate_branch: '' autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' autoupdate_schedule: quarterly - skip: [ ] + skip: [] submodules: false diff --git a/.readthedocs.yml
b/.readthedocs.yml index 10041d92..1d3914f1 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,7 +20,7 @@ build: python: "mambaforge-22.9" conda: - environment: environment.yml + environment: docs/environment.yml # Optionally set the version of Python and requirements required to build your docs python: diff --git a/HISTORY.rst b/HISTORY.rst index 00f22668..8b9071fc 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,10 +4,22 @@ Version History v0.12.0 (unreleased) -------------------- +New Features +^^^^^^^^^^^^ +* ``clisops.ops.regrid``, ``clisops.core.regrid``, ``clisops.core.Weights`` and ``clisops.core.Grid`` added (#TBA), allowing the remapping of geospatial data on various grids by applying the `xESMF <https://github.com/pangeo-data/xESMF>`_ regridder. + Bug Fixes ^^^^^^^^^ * Calling `subset_shape()` with a `locstream case` (#288) returned all coordinates inside `inner_mask`, which is equivalent to the bounding box of the polygon, not the area inside the polygon. Fixed by defining the `inner_mask` in `subset_shape()` for the locstream case (#292). +Other Changes +^^^^^^^^^^^^^ +* Extending the removal of redundant _FillValue attributes to all data variables and coordinates. +* Extending the removal of redundant coordinates in the coordinates variable attribute from bounds to all data variables. +* GitHub Workflows for upstream dependencies are now examined on a schedule or via `workflow_dispatch` (#243). +* `black` steps are now called `lint` for clarity and inclusiveness of other linting hooks (#243). +* pre-commit hooks now include checks for TOML files, and for ReadTheDocs and GitHub Actions configuration files (#243). + v0.11.0 (2023-08-22) -------------------- diff --git a/README.rst b/README.rst index f2a6e925..dfa97661 100644 --- a/README.rst +++ b/README.rst @@ -30,6 +30,24 @@ The package provides the following operations: * average * regrid +Online Demo +----------- + +.. + todo: Links have to be adjusted to the master or respective branch! + +You can try clisops online using Binder (just click on the binder link below), +or view the notebooks on NBViewer. + +.. image:: https://mybinder.org/badge_logo.svg + :target: https://mybinder.org/v2/gh/roocs/clisops/master?filepath=notebooks + :alt: Binder Launcher + +..
image:: https://raw.githubusercontent.com/jupyter/design/master/logos/Badges/nbviewer_badge.svg + :target: https://nbviewer.jupyter.org/github/roocs/clisops/tree/master/notebooks/ + :alt: NBViewer + :height: 20 + Credits ------- diff --git a/binder/environment.yml b/binder/environment.yml new file mode 100644 index 00000000..14e35a36 --- /dev/null +++ b/binder/environment.yml @@ -0,0 +1,37 @@ +name: clisops +channels: + - conda-forge +dependencies: + - python >=3.8,<3.12 + - flit + - bottleneck >=1.3.1 + - cf_xarray >=0.8.5 + - cftime >=1.4.1 + - dask >=2.6.0 + - gdal >=3.0 + - geopandas >=0.11 + - loguru >=0.5.3 + - netCDF4 >=1.4 + - numpy >=1.16 + - packaging + - pandas >=1.0.3 + - pooch + - poppler >=0.67 + - pyproj >=3.3.0 + - requests >=2.0 + - roocs-grids>=0.1.2 + - roocs-utils >=0.6.4,<0.7 + - shapely >=1.9 + - xarray >=0.21,<2023.3.0 # https://github.com/pydata/xarray/issues/7794 + - xesmf >=0.8.2 + - cartopy>=0.20.2 + - jupyterlab + - GitPython + - matplotlib>=3.5.2 + # Upstream + - pip + - pip: + - psy-maps + - clisops +# - cf-xarray @ git+https://github.com/xarray-contrib/cf-xarray/@main#egg=cf-xarray +# - roocs-utils @ git+https://github.com/roocs/roocs-utils.git@master#egg=roocs-utils diff --git a/binder/postBuild b/binder/postBuild index 4fe977f3..7cbed1f8 100644 --- a/binder/postBuild +++ b/binder/postBuild @@ -1,2 +1,3 @@ -pip install matplotlib -pip install . +set -e + +git clone https://github.com/roocs/mini-esgf-data ~/.mini-esgf-data diff --git a/clisops/core/__init__.py b/clisops/core/__init__.py index b7a2fcc5..bb5bd712 100644 --- a/clisops/core/__init__.py +++ b/clisops/core/__init__.py @@ -9,3 +9,5 @@ subset_time_by_components, subset_time_by_values, ) + +from .regrid import Grid, Weights, regrid, weights_cache_init, weights_cache_flush diff --git a/clisops/core/regrid.py b/clisops/core/regrid.py new file mode 100644 index 00000000..74f88879 --- /dev/null +++ b/clisops/core/regrid.py @@ -0,0 +1,1865 @@ +"""Regrid module.""" +from __future__ import annotations + +import functools +import json +import os +import warnings +from collections import ChainMap, OrderedDict +from glob import glob +from hashlib import md5 +from math import sqrt +from pathlib import Path + +import cf_xarray # noqa +import numpy as np +import roocs_grids +import xarray as xr +from packaging.version import Version +from roocs_utils.exceptions import InvalidParameterValue + +import clisops.utils.dataset_utils as clidu +from clisops import CONFIG +from clisops import __version__ as __clisops_version__ +from clisops.utils.common import check_dir, require_module +from clisops.utils.output_utils import FileLock, create_lock + +# Try importing xesmf and set to None if it is missing or below the minimum version +# If set to None, the `require_module` decorator will throw an exception +XESMF_MINIMUM_VERSION = "0.7.0" +try: + import xesmf as xe + + if Version(xe.__version__) < Version(XESMF_MINIMUM_VERSION): + raise ValueError() +except (ModuleNotFoundError, ValueError): + xe = None + + +# Read coordinate variable precision from the clisops configuration (roocs.ini) +# All horizontal coordinate variables will be rounded to this precision +coord_precision_hor = int(CONFIG["clisops:coordinate_precision"]["hor_coord_decimals"]) + +# Check if xESMF module is imported - decorator, used below +require_xesmf = functools.partial( + require_module, module=xe, module_name="xESMF", min_version=XESMF_MINIMUM_VERSION +) + + +def weights_cache_init(weights_dir: str | Path) -> None: + """Initialize
global variable `weights_dir` as used by the Weights class. + + Parameters + ---------- + weights_dir : str or Path + Directory name to initialize the local weights cache in. + Will be created if it does not exist. + Per default, this function is called upon import with weights_dir as defined in roocs.ini. + + Returns + ------- + None + """ + # Overwrite CONFIG entry with new value + CONFIG["clisops:grid_weights"]["local_weights_dir"] = str(weights_dir) + + # Create directory tree if required + if not os.path.isdir(weights_dir): + os.makedirs(weights_dir) + + +# Initialize weights cache as defined in the clisops configuration (roocs.ini) +weights_cache_init(CONFIG["clisops:grid_weights"]["local_weights_dir"]) + +# Ensure local weight storage directory exists - decorator, used below +check_weights_dir = functools.partial( + check_dir, dr=CONFIG["clisops:grid_weights"]["local_weights_dir"] +) + + +def weights_cache_flush( + weights_dir_init: str | Path | None = "", + dryrun: bool | None = False, + verbose: bool | None = False, +) -> None: + """Flush and reinitialize the local weights cache. + + Parameters + ---------- + weights_dir_init : str, optional + Directory name to reinitialize the local weights cache in. + Will be created if it does not exist. + The default is CONFIG["clisops:grid_weights"]["local_weights_dir"] as defined in roocs.ini + (or as redefined by a manual weights_cache_init call). + dryrun : bool, optional + If True, it will only print all files that would get deleted. The default is False. + verbose : bool, optional + If True, and dryrun is False, will print all files that are getting deleted. + The default is False. + + Returns + ------- + None + """ + # Read weights_dir from CONFIG + weights_dir = CONFIG["clisops:grid_weights"]["local_weights_dir"] + + if dryrun: + print(f"Flushing the clisops weights cache ('{weights_dir}') would remove:") + elif verbose: + print(f"Flushing the clisops weights cache ('{weights_dir}'). Removing ...") + + # Find and delete/report weight files, grid files and the json files containing the metadata + if os.path.isdir(weights_dir): + flist_weights = glob(f"{weights_dir}/weights_{'?'*32}_{'?'*32}_*.nc") + flist_meta = glob(f"{weights_dir}/weights_{'?'*32}_{'?'*32}_*.json") + flist_grids = glob(f"{weights_dir}/grid_{'?'*32}.nc") + if flist_weights != [] or flist_grids != [] or flist_meta != []: + for f in flist_meta + flist_weights + flist_grids: + if dryrun or verbose: + print(f" - {f}") + if not dryrun: + os.remove(f) + else: + if dryrun or verbose: + print("No weight or grid files found. Cache empty?") + elif dryrun: + print("No weight or grid files found. Cache empty?") + + # Reinitialize local weights cache + if not dryrun: + if not weights_dir_init: + weights_dir_init = weights_dir + weights_cache_init(weights_dir_init) + if verbose: + print(f"Initialized new weights cache at {weights_dir_init}") + + +class Grid: + """Create a Grid object that is suitable to serve as source or target grid of the Weights class. + + Pre-processes coordinate variables of input dataset (eg. create or read dataset from input, + reformat, generate bounds, identify duplicated and collapsing cells, determine zonal / east-west extent). + + Parameters + ---------- + ds : xr.Dataset or xr.DataArray, optional + Uses horizontal coordinates of an xarray.Dataset or xarray.DataArray to create a Grid object. + The default is None. + grid_id : str, optional + Create the Grid object from a selection of pre-defined grids, e.g. "1deg" or "2pt5deg". 
+ The grids are provided via the roocs_grids package (https://github.com/roocs/roocs-grids). + A special setting is "adaptive"/"auto", which requires the parameter 'ds' to be specified as well, + and creates a regular lat-lon grid of the same extent and approximate resolution as the grid + described by 'ds'. The default is None. + grid_instructor : tuple, float or int, optional + Create a regional or global regular lat-lon grid using xESMF utility functions. + - Global grid: grid_instructor = (lon_step, lat_step) or grid_instructor = step + - Regional grid: grid_instructor = (lon_start, lon_end, lon_step, lat_start, lat_end, lat_step) + or grid_instructor = (start, end, step). The default is None. + compute_bounds : bool, optional + Compute latitude and longitude bounds if the dataset has none defined. + The default is False. + + """ + + def __init__( + self, + ds: xr.Dataset | xr.DataArray | None = None, + grid_id: str | None = None, + grid_instructor: tuple | float | int | None = None, + compute_bounds: bool | None = False, + ): + """Initialize the Grid object. Only 2D horizontal grids are supported.""" + # All attributes - defaults + self.type = None + self.format = None + self.extent = None + self.nlat = 0 + self.nlon = 0 + self.ncells = 0 + self.lat = None + self.lon = None + self.lat_bnds = None + self.lon_bnds = None + self.mask = None + self.source = None + self.hash = None + self.coll_mask = None + self.contains_collapsed_cells = None + self.contains_duplicated_cells = None + + # Create grid_instructor as empty tuple if None + grid_instructor = grid_instructor or tuple() + + # Grid from Dataset/DataArray, grid_instructor or grid_id + if isinstance(ds, (xr.Dataset, xr.DataArray)): + if grid_id in ["auto", "adaptive"]: + self._grid_from_ds_adaptive(ds) + else: + self.ds = ds + self.format = self.detect_format() + self.source = "Dataset" + elif grid_instructor: + self._grid_from_instructor(grid_instructor) + elif grid_id: + self._grid_from_id(grid_id) + else: + raise InvalidParameterValue( + "xarray.Dataset, grid_id or grid_instructor have to be specified as input."
+ ) + + # Force format CF + if self.format not in ["CF"]: + self.grid_reformat(grid_format="CF") + + # Detect latitude and longitude coordinates + self.lat = self.detect_coordinate("latitude") + self.lon = self.detect_coordinate("longitude") + self.lat_bnds = self.detect_bounds(self.lat) + self.lon_bnds = self.detect_bounds(self.lon) + + # Make sure standard_names are set for the coordinates + self.ds[self.lat].attrs["standard_name"] = "latitude" + self.ds[self.lon].attrs["standard_name"] = "longitude" + + # Detect type + if not self.type: + self.type = self.detect_type() + + # Unstagger the grid if necessary (to be done before halo removal - not yet implemented) + self._grid_unstagger() + + # Lon/Lat dimension sizes + self.nlat, self.nlon, self.ncells = self.detect_shape() + + # Extent of the grid (global or regional) + if not self.extent: + self.extent = self.detect_extent() + + # Get a permanent mask if there is + # self.mask = self._detect_mask() + + # Clean coordinate variables out of data_vars + if isinstance(self.ds, xr.Dataset): + self._set_data_vars_and_coords() + + # Detect duplicated grid cells / halos + if self.contains_duplicated_cells is None: + self.contains_duplicated_cells = self._grid_detect_duplicated_cells() + + # Compute bounds if not specified and if possible + if (not self.lat_bnds or not self.lon_bnds) and compute_bounds: + self._compute_bounds() + + # TODO: possible step to use np.around(in_array, decimals [, out_array]) + # 6 decimals corresponds to precision of ~ 0.1m (deg), 6m (rad) + self._cap_precision(coord_precision_hor) + + # Create md5 hash of the coordinate variable arrays + # Takes into account lat/lon + bnds + mask (if defined) + self.hash = self._compute_hash() + + # Detect collapsing grid cells + if self.lat_bnds and self.lon_bnds and self.contains_collapsed_cells is None: + self._grid_detect_collapsed_cells() + + self.title = self._get_title() + + def __str__(self): + """Return short string representation of a Grid object.""" + if self.type == "unstructured": + grid_str = str(self.ncells) + "_cells_grid" + else: + grid_str = str(self.nlat) + "x" + str(self.nlon) + "_cells_grid" + return grid_str + + def __repr__(self): + """Return full representation of a Grid object.""" + info = ( + f"clisops {self.__str__()}\n" + + ( + f"Lat x Lon: {self.nlat} x {self.nlon}\n" + if self.type != "unstructured" + else "" + ) + + f"Gridcells: {self.ncells}\n" + + f"Format: {self.format}\n" + + f"Type: {self.type}\n" + + f"Extent: {self.extent}\n" + + f"Source: {self.source}\n" + + "Bounds? {}\n".format( + self.lat_bnds is not None and self.lon_bnds is not None + ) + + f"Collapsed cells? {self.contains_collapsed_cells}\n" + + f"Duplicated cells? {self.contains_duplicated_cells}\n" + + f"Permanent Mask: {self.mask}\n" + + f"md5 hash: {self.hash}" + ) + return info + + def _get_title(self) -> str: + """Generate a title for the Grid with more information than the basic string representation.""" + if self.source.startswith("Predefined_"): + return ".".join( + ga + for ga in roocs_grids.grid_annotations[ + self.source.replace("Predefined_", "") + ].split(".") + if "land-sea mask" not in ga + ) + else: + if self.type != "unstructured": + return f"{self.extent} {self.type} {self.nlat}x{self.nlon} ({self.ncells} cells) grid." + else: + return f"{self.extent} {self.type} {self.ncells} cells grid." 
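As a usage sketch for the cache helpers defined above — illustrative only, with a placeholder directory:

from clisops.core import weights_cache_init, weights_cache_flush

# Point the weights cache at a custom directory (created if missing); this
# overwrites CONFIG["clisops:grid_weights"]["local_weights_dir"] for the session.
weights_cache_init("/tmp/clisops_grid_weights")

# List what a flush would delete, then actually flush and reinitialize the cache.
weights_cache_flush(dryrun=True)
weights_cache_flush(verbose=True)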
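Similarly, the constructor paths documented in the class docstring above translate into calls along these lines (the input file name is a placeholder):

import xarray as xr
from clisops.core import Grid

ds = xr.open_dataset("tas_example.nc")  # placeholder dataset

grid_src = Grid(ds=ds)                        # grid taken from an existing dataset
grid_1deg = Grid(grid_id="1deg")              # predefined grid shipped with roocs_grids
grid_adapt = Grid(ds=ds, grid_id="adaptive")  # regular grid matching ds in extent/resolution
grid_half = Grid(grid_instructor=0.5)         # global regular 0.5 x 0.5 degree grid

print(grid_half)  # -> "360x720_cells_grid"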
+ + def _grid_from_id(self, grid_id): + """Load pre-defined grid from netCDF file.""" + try: + grid_file = roocs_grids.get_grid_file(grid_id) + grid = xr.open_dataset(grid_file) + except KeyError: + raise KeyError(f"The grid_id '{grid_id}' you specified does not exist.") + + # Set attributes + self.ds = grid + self.source = "Predefined_" + grid_id + self.type = "regular_lat_lon" + self.format = self.detect_format() + + @require_xesmf + def _grid_from_instructor(self, grid_instructor: tuple | float | int): + """Process instructions to create regional or global grid (uses xESMF utility functions).""" + # Create tuple of length 1 if input is either float or int + if isinstance(grid_instructor, (int, float)): + grid_instructor = (grid_instructor,) + + # Call xesmf.util functions to create the grid + if len(grid_instructor) not in [1, 2, 3, 6]: + raise InvalidParameterValue( + "The grid_instructor has to be a tuple of length 1, 2, 3 or 6." + ) + elif len(grid_instructor) in [1, 2]: + grid = xe.util.grid_global(grid_instructor[0], grid_instructor[-1]) + elif len(grid_instructor) in [3, 6]: + grid = xe.util.grid_2d( + grid_instructor[0], + grid_instructor[1], + grid_instructor[2], + grid_instructor[-3], + grid_instructor[-2], + grid_instructor[-1], + ) + + # Set attributes + self.ds = grid + self.source = "xESMF" + self.type = "regular_lat_lon" + self.format = "xESMF" + + @require_xesmf + def _grid_from_ds_adaptive(self, ds: xr.Dataset | xr.DataArray): + """Create a Grid of similar extent and resolution as the input dataset.""" + # TODO: dachar/daops to deal with missing values occurring in the coordinate variables + # while no _FillValue/missing_value attribute is set + # -> FillValues might otherwise get selected as minimum/maximum lat/lon value + # since they are not masked + + # Create temporary Grid object out of input dataset + grid_tmp = Grid(ds=ds) + + # Determine "edges" of the grid + xfirst = float(grid_tmp.ds[grid_tmp.lon].min()) + xlast = float(grid_tmp.ds[grid_tmp.lon].max()) + yfirst = float(grid_tmp.ds[grid_tmp.lat].min()) + ylast = float(grid_tmp.ds[grid_tmp.lat].max()) + + # fix for regional grids that wrap around the Greenwich meridian + if grid_tmp.extent == "regional" and (xfirst > 180 or xlast > 180): + grid_tmp.ds.lon.data = grid_tmp.ds.lon.where( + grid_tmp.ds.lon <= 180, grid_tmp.ds.lon - 360.0 + ) + xfirst = float(grid_tmp.ds[grid_tmp.lon].min()) + xlast = float(grid_tmp.ds[grid_tmp.lon].max()) + + # For unstructured grids: + # Distribute the number of grid cells to nlat and nlon, in proportion + # to extent in meridional and zonal direction + if grid_tmp.type == "unstructured": + xsize = int( + sqrt(abs(xlast - xfirst) / abs(ylast - yfirst) * grid_tmp.ncells) + ) + ysize = int( + sqrt(abs(ylast - yfirst) / abs(xlast - xfirst) * grid_tmp.ncells) + ) + # Else, use nlat and nlon of the dataset + else: + xsize = grid_tmp.nlon + ysize = grid_tmp.nlat + + # Compute meridional / zonal resolution (=increment) + xinc = (xlast - xfirst) / (xsize - 1) + yinc = (ylast - yfirst) / (ysize - 1) + xrange = [0.0, 360.0] if xlast > 180 else [-180.0, 180.0] + xfirst = xfirst - xinc / 2.0 + xlast = xlast + xinc / 2.0 + xfirst = xfirst if xfirst > xrange[0] - xinc / 2.0 else xrange[0] + xlast = xlast if xlast < xrange[1] + xinc / 2.0 else xrange[1] + yfirst = yfirst - yinc / 2.0 + ylast = ylast + yinc / 2.0 + yfirst = yfirst if yfirst > -90.0 else -90.0 + ylast = ylast if ylast < 90.0 else 90.0 + + # Create regular lat-lon grid with these specifics + self._grid_from_instructor((xfirst, xlast, xinc,
yfirst, ylast, yinc)) + + def grid_reformat(self, grid_format: str, keep_attrs: bool = False): + """Reformat the Dataset attached to the Grid object to a target format. + + Parameters + ---------- + grid_format : str + Target format of the reformat operation. Currently supported are 'SCRIP', 'CF', 'xESMF'. + keep_attrs : bool + Whether to keep the global attributes. + + Returns + ------- + None + The reformatted dataset replaces ``self.ds`` in place. + """ + # TODO: Extend for formats CF, xESMF, ESMF, UGRID, SCRIP + # If CF and self.type=="regular_lat_lon": + # ensure lat/lon are 1D each and bounds are nlat,2 and nlon,2 + # TODO: When 2D coordinates will be changed to 1D index coordinates + # xarray.assign_coords might be necessary, or alternatively, + # define a new Dataset and move all data_vars and aux. coords across. + + # Generate reformat operation string + reformat_operation = "reformat_" + self.format + "_to_" + grid_format + + # Conduct reformat operation if defined in clisops.utils.dataset_utils + if hasattr(clidu, reformat_operation): + self.ds = getattr(clidu, reformat_operation)( + ds=self.ds, keep_attrs=keep_attrs + ) + self.format = grid_format + else: + raise Exception( + "Converting the grid format from %s to %s is not yet supported." + % (self.format, grid_format) + ) + + def _grid_unstagger(self) -> None: + """Interpolate to cell center from cell edges, rotate vector variables in lat/lon direction. + + Warning + ------- + This method is not yet implemented. + """ + # TODO + # Plan: + # Check if it is vector and not scalar data (e.g. by variable name? No other idea yet.) + # Unstagger if needed. + # a) Provide the unstaggered grid (from another dataset with scalar variable) or provide + # the other vector component? One of both might be required. + # b) Rotate the vector in zonal / meridional direction and interpolate to + # cell center of unstaggered grid + # c) Flux direction seems to be important for the rotation (see cdo mrotuvb), how to infer that? + # d) Grids staggered in vertical direction, w-component? Is that important at all for + # horizontal regridding, maybe only for 3D-unstructured grids? + # All in all, automating this process is hardly feasible.
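+ # As a rough sketch, the interpolation part of step b) alone could resemble + # u_center = 0.5 * (u[..., :-1] + u[..., 1:]) for an Arakawa-C staggered + # u-component; the rotation and flux-direction questions above are the hard part.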
+ pass + + def _grid_detect_duplicated_cells(self) -> bool: + """Detect a possible grid halo / duplicated cells.""" + # Create array of (ilat, ilon) tuples + if self.ds[self.lon].ndim == 2 or ( + self.ds[self.lon].ndim == 1 and self.type == "unstructured" + ): + latlon_halo = np.array( + list( + zip( + self.ds[self.lat].values.ravel(), + self.ds[self.lon].values.ravel(), + ) + ), + dtype=("float32,float32"), + ).reshape(self.ds[self.lon].values.shape) + else: + latlon_halo = list() + + # For 1D regular_lat_lon + if isinstance(latlon_halo, list): + dup_rows = [ + i + for i in list(range(self.ds[self.lat].shape[0])) + if i not in np.unique(self.ds[self.lat], return_index=True)[1] + ] + dup_cols = [ + i + for i in list(range(self.ds[self.lon].shape[0])) + if i not in np.unique(self.ds[self.lon], return_index=True)[1] + ] + if dup_cols != [] or dup_rows != []: + return True + + # For 1D unstructured + elif self.type == "unstructured" and self.ds[self.lon].ndim == 1: + mask_duplicates = self._create_duplicate_mask(latlon_halo) + dup_cells = np.where(mask_duplicates)[0] + if dup_cells.size > 0: + return True + + # For 2D coordinate variables + else: + mask_duplicates = self._create_duplicate_mask(latlon_halo) + # All duplicated row indices: + dup_rows = list() + for i in range(mask_duplicates.shape[0]): + if all(mask_duplicates[i, :]): + dup_rows.append(i) + # All duplicated column indices: + dup_cols = list() + for j in range(mask_duplicates.shape[1]): + if all(mask_duplicates[:, j]): + dup_cols.append(j) + for i in dup_rows: + mask_duplicates[i, :] = False + for j in dup_cols: + mask_duplicates[:, j] = False + # All partially duplicated row indices: + dup_part_rows = list() + for i in range(mask_duplicates.shape[0]): + if any(mask_duplicates[i, :]): + dup_part_rows.append(i) + # All partially duplicated column indices: + dup_part_cols = list() + for j in range(mask_duplicates.shape[1]): + if any(mask_duplicates[:, j]): + dup_part_cols.append(j) + if ( + dup_part_cols != [] + or dup_part_rows != [] + or dup_cols != [] + or dup_rows != [] + ): + return True + return False + + @staticmethod + def _create_duplicate_mask(arr): + """Create duplicate mask helper function.""" + arr_flat = arr.ravel() + mask = np.zeros_like(arr_flat, dtype=bool) + mask[np.unique(arr_flat, return_index=True)[1]] = True + mask_duplicates = np.where(mask, False, True).reshape(arr.shape) + return mask_duplicates + + def detect_format(self) -> str: + """Detect format of a dataset. Currently supported are 'CF', 'SCRIP', 'xESMF'. + + Returns + ------- + str + The format, if supported; otherwise an Exception is raised. + """ + return clidu.detect_format(ds=self.ds) + + def detect_type(self) -> str: + """Detect type of the grid as one of "regular_lat_lon", "curvilinear", or "unstructured". + + Otherwise, an Exception is raised if the grid type is not supported. + + Returns + ------- + str + The detected grid type.
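+ + Examples + -------- + A minimal sketch, assuming ``ds`` is a CF-compliant xarray.Dataset with + 2D latitude / longitude coordinate variables: + + >>> Grid(ds=ds).type + 'curvilinear'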
+ """ + # TODO: Extend for other formats for regular_lat_lon, curvilinear / rotated_pole, unstructured + + if self.format == "CF": + # 1D coordinate variables + if self.ds[self.lat].ndim == 1 and self.ds[self.lon].ndim == 1: + lat_1D = self.ds[self.lat].dims[0] + lon_1D = self.ds[self.lon].dims[0] + # if lat_1D in ds[var].dims and lon_1D in ds[var].dims: + if not self.lat_bnds or not self.lon_bnds: + if lat_1D == lon_1D: + return "unstructured" + else: + return "regular_lat_lon" + else: + if ( + lat_1D == lon_1D + and all( + [ + self.ds[bnds].ndim == 2 + for bnds in [self.lon_bnds, self.lat_bnds] + ] + ) + and all( + [ + self.ds.dims[dim] > 2 + for dim in [ + self.ds[self.lon_bnds].dims[-1], + self.ds[self.lat_bnds].dims[-1], + ] + ] + ) + ): + return "unstructured" + elif all( + [ + self.ds[bnds].ndim == 2 + for bnds in [self.lon_bnds, self.lat_bnds] + ] + ) and all( + [ + self.ds.dims[dim] == 2 + for dim in [ + self.ds[self.lon_bnds].dims[-1], + self.ds[self.lat_bnds].dims[-1], + ] + ] + ): + return "regular_lat_lon" + else: + raise Exception("The grid type is not supported.") + + # 2D coordinate variables + elif self.ds[self.lat].ndim == 2 and self.ds[self.lon].ndim == 2: + # Test for curvilinear or restructure lat/lon coordinate variables + # TODO: Check if regular_lat_lon despite 2D + # - requires additional function checking + # lat[:,i]==lat[:,j] for all i,j + # lon[i,:]==lon[j,:] for all i,j + # - and if that is the case to extract lat/lon and *_bnds + # lat[:]=lat[:,j], lon[:]=lon[j,:] + # lat_bnds[:, 2]=[min(lat_bnds[:,j, :]), max(lat_bnds[:,j, :])] + # lon_bnds similar + if not self.ds[self.lat].shape == self.ds[self.lon].shape: + raise Exception("The grid type is not supported.") + else: + if not self.lat_bnds or not self.lon_bnds: + return "curvilinear" + else: + # Shape of curvilinear bounds either [nlat, nlon, 4] or [nlat+1, nlon+1] + if list(self.ds[self.lat].shape) + [4] == list( + self.ds[self.lat_bnds].shape + ) and list(self.ds[self.lon].shape) + [4] == list( + self.ds[self.lon_bnds].shape + ): + return "curvilinear" + elif [si + 1 for si in self.ds[self.lat].shape] == list( + self.ds[self.lat_bnds].shape + ) and [si + 1 for si in self.ds[self.lon].shape] == list( + self.ds[self.lon_bnds].shape + ): + return "curvilinear" + else: + raise Exception("The grid type is not supported.") + + # >2D coordinate variables, or coordinate variables of different dimensionality + else: + raise Exception("The grid type is not supported.") + + # Other formats + else: + raise Exception( + "Grid type can only be determined for datasets following the CF conventions." + ) + + def detect_extent(self) -> str: + """Determine the grid extent in zonal / east-west direction ('regional' or 'global'). + + Returns + ------- + str + 'regional' or 'global'. + """ + # TODO: support Units "rad" next to "degree ..." + # TODO: additionally check that leftmost and rightmost lon_bnds touch for each row? + # + # TODO: perform a roll if necessary in case the longitude values are not in the range (0,360) + # - Grids that range for example from (-1. , 359.) 
+ # - Grids that are totally out of range, like GFDL (-300, 60) + # ds=dataset_utils.check_lon_alignment(ds, (0,360)) # does not work yet for this purpose + + # Determine min/max lon/lat values + xfirst = float(self.ds[self.lon].min()) + xlast = float(self.ds[self.lon].max()) + yfirst = float(self.ds[self.lat].min()) + ylast = float(self.ds[self.lat].max()) + + # Approximate the grid resolution + if self.ds[self.lon].ndim == 2 and self.ds[self.lat].ndim == 2: + xsize = self.nlon + ysize = self.nlat + xinc = (xlast - xfirst) / (xsize - 1) + yinc = (ylast - yfirst) / (ysize - 1) + approx_res = (xinc + yinc) / 2.0 + elif self.ds[self.lon].ndim == 1: + if self.type == "unstructured": + # Distribute the number of grid cells to nlat and nlon, + # in proportion to extent in zonal and meridional direction + # TODO: Alternatively one can use the kdtree method to calculate the approx. resolution + # once it is implemented here + xsize = int( + sqrt(abs(xlast - xfirst) / abs(ylast - yfirst) * self.ncells) + ) + ysize = int( + sqrt(abs(ylast - yfirst) / abs(xlast - xfirst) * self.ncells) + ) + xinc = (xlast - xfirst) / (xsize - 1) + yinc = (ylast - yfirst) / (ysize - 1) + approx_res = (xinc + yinc) / 2.0 + else: + approx_res = np.average( + np.absolute( + self.ds[self.lon].values[1:] - self.ds[self.lon].values[:-1] + ) + ) + else: + raise Exception( + "Only 1D and 2D longitude and latitude coordinate variables supported." + ) + + # Check the range of the lon values + atol = 2.0 * approx_res + lon_max = float(self.ds[self.lon].max()) + lon_min = float(self.ds[self.lon].min()) + if lon_min < -atol and lon_min > -180.0 - atol and lon_max < 180.0 + atol: + min_range, max_range = (-180.0, 180.0) + elif lon_min > -atol and lon_max < 360.0 + atol: + min_range, max_range = (0.0, 360.0) + # TODO: for shifted longitudes, eg. (-300,60)? I forgot what it was for but likely it is irrelevant + # elif lon_min < -180.0 - atol or lon_max > 360.0 + atol: + # raise Exception( + # "The longitude values have to be within the range (-180, 360)!" + # ) + # elif lon_max - lon_min > 360.0 - atol and lon_max - lon_min < 360.0 + atol: + # min_range, max_range = ( + # lon_min - approx_xres / 2.0, + # lon_max + approx_xres / 2.0, + # ) + else: + raise Exception( + "The longitude values have to be within the range (-180, 360)." + ) + + # Generate a histogram with bins for sections along a latitudinal circle, + # width of the bins/sections dependent on the resolution in x-direction + extent_hist = np.histogram( + self.ds[self.lon], + bins=np.arange(min_range - approx_res, max_range + approx_res, atol), + ) + + # If the counts for all bins are greater than zero, the grid is considered global in x-direction + # Yet, this information is only needed for xesmf.Regridder, as "periodic in longitude" + # and hence, the extent in y-direction does not matter. + # If at some point the qualitative extent in y-direction has to be checked, one needs to + # take into account that global ocean grids often tend to end at the antarctic coast and do not + # reach up to -90°S. + if np.all(extent_hist[0]): + return "global" + else: + return "regional" + + def _detect_mask(self): + """Detect mask helper function. + + Warning + ------- + Not yet implemented, if at all necessary (e.g. for reformatting to SCRIP etc.). + """ + # TODO + # Plan: + # Depending on the format, the mask is stored as extra variable. + # If self.format=="CF": An extra variable mask could be generated from missing values? 
+ # This could be an extra function of the reformatter with target format xESMF/SCRIP/... + # For CF as target format, this mask could be applied to mask the data for all variables that + # are not coordinate or auxiliary variables (infer from attributes if possible). + # If a vertical dimension is present, this should not be done. + # In general one might be better off with the adaptive masking and this would be + # just a nice to have thing in case of reformatting and storing the grid on disk. + + # ds["mask"]=xr.where(~np.isnan(ds['var'].isel(time=0)), 1, 0).astype(int) + return + + def detect_shape(self) -> tuple[int, int, int]: + """Detect the shape of the grid. + + Returns a tuple of (nlat, nlon, ncells). For an unstructured grid nlat and nlon are not defined + and therefore the returned tuple will be (ncells, ncells, ncells). + + Returns + ------- + int + Number of latitude points in the grid. + int + Number of longitude points in the grid. + int + Number of cells in the grid. + """ + # Call clisops.utils.dataset_utils function + return clidu.detect_shape( + ds=self.ds, lat=self.lat, lon=self.lon, grid_type=self.type + ) + + def detect_coordinate(self, coord_type: str) -> str: + """Use cf_xarray to obtain the variable name of the requested coordinate. + + Parameters + ---------- + coord_type : str + Coordinate type understood by cf-xarray, e.g. 'lat', 'lon', ... + + Raises + ------ + AttributeError + Raised if the requested coordinate cannot be identified. + + Returns + ------- + str + Coordinate variable name. + """ + # Make use of cf-xarray accessor + coord = self.ds.cf[coord_type] + # coord = get_coord_by_type(self.ds, coord_type, ignore_aux_coords=False) + + # Return the name of the coordinate variable + try: + return coord.name + except AttributeError: + raise AttributeError( + "A %s coordinate cannot be identified in the dataset." % coord_type + ) + + def detect_bounds(self, coordinate: str) -> str | None: + """Use cf_xarray to obtain the variable name of the requested coordinate's bounds. + + Parameters + ---------- + coordinate : str + Name of the coordinate variable to determine the bounds from. + + Returns + ------- + str, optional + Returns the variable name of the requested coordinate bounds. + Returns None if the variable has no bounds or if they cannot be identified. + """ + try: + return self.ds.cf.bounds[coordinate][0] + except (KeyError, AttributeError): + warnings.warn( + "For coordinate variable '%s' no bounds can be identified." % coordinate + ) + return + + def _grid_detect_collapsed_cells(self): + """Detect collapsing grid cells.
Requires defined bounds.""" + mask_lat = self._create_collapse_mask(self.ds[self.lat_bnds].data) + mask_lon = self._create_collapse_mask(self.ds[self.lon_bnds].data) + # for regular lat-lon grids, create 2D coordinate arrays + if ( + mask_lat.shape != mask_lon.shape + and mask_lat.ndim == 1 + and mask_lon.ndim == 1 + ): + mask_lon, mask_lat = np.meshgrid(mask_lon, mask_lat) + self.coll_mask = mask_lat | mask_lon + self.contains_collapsed_cells = bool(np.any(self.coll_mask)) + + @staticmethod + def _create_collapse_mask(arr): + """Grid cells collapsing to lines or points.""" + orig_shape = arr.shape[:-1] # [nlon, nlat, nbnds] -> [nlon, nlat] + arr_flat = arr.reshape(-1, arr.shape[-1]) # -> [nlon x nlat, nbnds] + arr_set = np.apply_along_axis(lambda x: len(set(x)), -1, arr_flat) + mask = np.zeros(arr_flat.shape[:-1], dtype=bool) + mask[arr_set == 1] = True + return mask.reshape(orig_shape) + + def _cap_precision(self, decimals: int) -> None: + """Round horizontal coordinate variables to specified precision using numpy.around. + + Parameters + ---------- + decimals : int + The decimal position / precision to round to. + + Returns + ------- + None + """ + # TODO: extend for vertical axis for vertical interpolation usecase + # 6 decimals corresponds to hor. precision of ~ 0.1m (deg), 6m (rad) + coord_dict = {} + attr_dict = {} + encoding_dict = {} + + # Assign the rounded values as new coordinate variables + for coord in [self.lat_bnds, self.lon_bnds, self.lat, self.lon]: + if coord: + attr_dict.update({coord: self.ds[coord].attrs}) + encoding_dict.update({coord: self.ds[coord].encoding}) + coord_dict.update( + { + coord: ( + self.ds[coord].dims, + np.around(self.ds[coord].data.astype(np.float64), decimals), + ) + } + ) + + # Restore the original attributes + if coord_dict: + self.ds = self.ds.assign_coords(coord_dict) + # Restore attrs and encoding - is there a proper way to do this?? (TODO) + for coord in [self.lat_bnds, self.lon_bnds, self.lat, self.lon]: + if coord: + self.ds[coord].attrs = attr_dict[coord] + self.ds[coord].encoding = encoding_dict[coord] + + def _compute_hash(self) -> str: + """Compute md5 checksum of each component of the horizontal grid, including a potentially defined mask. + + Stores the individual checksum of each component (lat, lon, lat_bnds, lon_bnds, mask) in a dictionary and + returns an overall checksum. + + Returns + ------- + str + md5 checksum of the checksums of all 5 grid components. + """ + # Create dictionary including the hashes for each grid component and store it as attribute + self.hash_dict = OrderedDict() + for coord, coord_var in OrderedDict( + [ + ("lat", self.lat), + ("lon", self.lon), + ("lat_bnds", self.lat_bnds), + ("lon_bnds", self.lon_bnds), + ("mask", self.mask), + ] + ).items(): + if coord_var: + self.hash_dict[coord] = md5( + str(self.ds[coord_var].values.tobytes()).encode("utf-8") + ).hexdigest() + else: + self.hash_dict[coord] = md5(b"undefined").hexdigest() + + # Return overall checksum for all 5 components + return md5("".join(self.hash_dict.values()).encode("utf-8")).hexdigest() + + def compare_grid( + self, ds_or_Grid: xr.Dataset | Grid, verbose: bool = False + ) -> bool: + """Compare two Grid objects. + + Will compare the checksum of two Grid objects, which depend on the lat and lon coordinate + variables, their bounds and if defined, a mask. + + Parameters + ---------- + ds_or_Grid : xarray.Dataset or Grid + Grid that the current Grid object shall be compared to. + verbose : bool + Whether to also print the result. 
The default is False. + + Returns + ------- + bool + Returns True if the two Grids are considered identical within the defined precision, else returns False. + """ + # Create temporary Grid object if ds_or_Grid is an xarray object + if isinstance(ds_or_Grid, xr.Dataset) or isinstance(ds_or_Grid, xr.DataArray): + grid_tmp = Grid(ds=ds_or_Grid) + elif isinstance(ds_or_Grid, Grid): + grid_tmp = ds_or_Grid + else: + raise InvalidParameterValue( + "The provided input has to be of one of the types [xarray.DataArray, xarray.Dataset, clisops.core.Grid]." + ) + + # Compare each of the five components and print result if verbose is active + if verbose: + diff = [ + coord_var + for coord_var in self.hash_dict + if self.hash_dict[coord_var] != grid_tmp.hash_dict[coord_var] + ] + if len(diff) > 0: + print(f"The two grids differ in their respective {', '.join(diff)}.") + else: + print("The two grids are considered equal.") + + # Return the result as boolean + return grid_tmp.hash == self.hash + + def _drop_vars(self, keep_attrs: bool = False) -> None: + """Remove all non-necessary (non-horizontal) coords and data_vars of the Grids' xarray.Dataset. + + Parameters + ---------- + keep_attrs : bool + Whether to keep the global attributes. The default is False. + """ + to_keep = [ + var for var in [self.lat, self.lon, self.lat_bnds, self.lon_bnds] if var + ] + to_drop = [ + var + for var in list(self.ds.data_vars) + list(self.ds.coords) + if var not in to_keep + ] + if not keep_attrs: + self.ds.attrs = {} + self.ds = self.ds.drop_vars(names=to_drop) + + def _transfer_coords( + self, source_grid: Grid, keep_attrs: str | bool = True + ) -> None: + """Transfer all non-horizontal coordinates and optionally global attributes between two Grid objects. + + Parameters + ---------- + source_grid : Grid + Source Grid object to transfer the coords from. + keep_attrs : bool or str, optional + Whether to transfer also the global attributes. + False: do not transfer the global attributes. + "target": preserve the global attributes of the target Grid object. + True: transfer the global attributes from source to target Grid object. + The default is True. + + Returns + ------- + None + """ + # Skip all coords with horizontal dimensions or + # coords with no dimensions that are not listed + # in the coordinates attribute of the data_vars + dims_to_skip = set( + source_grid.ds[source_grid.lat].dims + source_grid.ds[source_grid.lon].dims + ) + coordinates_attr = [] + for var in source_grid.ds.data_vars: + cattr = ChainMap( + source_grid.ds[var].attrs, source_grid.ds[var].encoding + ).get("coordinates", "") + if cattr: + coordinates_attr += cattr.split() + to_skip = [ + var + for var in list(source_grid.ds.coords) + if source_grid.ds[var].ndim == 0 and var not in coordinates_attr + ] + to_transfer = [ + var + for var in list(source_grid.ds.coords) + if all([dim not in source_grid.ds[var].dims for dim in dims_to_skip]) + ] + coord_dict = {} + for coord in to_transfer: + if coord not in to_skip: + coord_dict.update({coord: source_grid.ds[coord]}) + if not keep_attrs: + self.ds.attrs = {} + elif keep_attrs == "target": + pass # attrs will be the original target grid values + elif keep_attrs: + self.ds.attrs.update(source_grid.ds.attrs) + else: + raise InvalidParameterValue("Illegal value for the parameter 'keep_attrs'.") + + self.ds = self.ds.assign_coords(coord_dict) + + def _set_data_vars_and_coords(self): + """(Re)set xarray.Dataset.coords appropriately. 
+ + After opening/creating an xarray.Dataset, likely coordinates can be found set as data_vars, + and data_vars set as coords. This method (re)sets the coords. Dimensionless variables that + are not registered in any "coordinates" attribute are per default reset to data_vars, + so xarray does not keep them in the dataset after remapping; an example for this is + "rotated_latitude_longitude". + """ + to_coord = [] + to_datavar = [] + + # Collect all (dimensionless) coordinates + coordinates_attr = [] + for var in self.ds.data_vars: + cattr = ChainMap(self.ds[var].attrs, self.ds[var].encoding).get( + "coordinates", "" + ) + if cattr: + coordinates_attr += cattr.split() + + # Also add cell_measure variables + cell_measures = list() + for cmtype, cm in self.ds.cf.cell_measures.items(): + cell_measures += cm + + # Set as coord for auxiliary coord. variables not supposed to be remapped + if self.ds[self.lat].ndim == 2: + for var in self.ds.data_vars: + if var in cell_measures: + to_coord.append(var) + elif self.ds[var].ndim < 2 and ( + self.ds[var].ndim > 0 or var in coordinates_attr + ): + to_coord.append(var) + elif self.ds[var].ndim == 0: + continue + elif self.ds[var].shape[-2:] != self.ds[self.lat].shape: + to_coord.append(var) + elif self.ds[self.lat].ndim == 1: + for var in self.ds.data_vars: + if var in cell_measures: + to_coord.append(var) + continue + elif self.ds[var].ndim == 0 and var in coordinates_attr: + to_coord.append(var) + continue + elif self.type == "unstructured": + if ( + len(self.ds[var].shape) > 0 + and (self.ds[var].shape[-1],) != self.ds[self.lat].shape + ): + to_coord.append(var) + else: + if not ( + self.ds[var].shape[-2:] == (self.nlat, self.nlon) + or self.ds[var].shape[-2:] == (self.nlon, self.nlat) + ): + to_coord.append(var) + + # Set coordinate bounds as coords + for var in [bnd for bnds in self.ds.cf.bounds.values() for bnd in bnds]: + if var in self.ds.data_vars: + to_coord.append(var) + + # Reset coords for variables supposed to be remapped (eg. ps) + for var in self.ds.coords: + if var not in [self.lat, self.lon] + [ + bnd for bnds in self.ds.cf.bounds.values() for bnd in bnds + ]: + if var in cell_measures: + continue + elif self.ds[var].ndim == 0 and var not in coordinates_attr: + to_datavar.append(var) + elif self.type == "unstructured": + if len(self.ds[var].shape) > 0 and ( + self.ds[var].shape[-1] == self.ncells + and self.ds[var].dims[-1] in self.ds[self.lat].dims + and var not in self.ds.dims + ): + to_datavar.append(var) + else: + if ( + len(self.ds[var].shape) > 0 + and ( + self.ds[var].shape[-2:] == (self.nlat, self.nlon) + or self.ds[var].shape[-2:] == (self.nlon, self.nlat) + ) + and all( + [ + dim in self.ds[var].dims + for dim in list(self.ds[self.lat].dims) + + list(self.ds[self.lon].dims) + ] + ) + ): + to_datavar.append(var) + + # Call xarray.Dataset.(re)set_coords + if to_coord: + self.ds = self.ds.set_coords(list(set(to_coord))) + if to_datavar: + self.ds = self.ds.reset_coords(list(set(to_datavar))) + + def _compute_bounds(self): + """Compute bounds for regular (rectangular or curvilinear) grids. + + The bounds will be attached as coords to the xarray.Dataset of the Grid object. + If no bounds can be created, a warning is issued. + """ + # TODO: This can be a public method as well, but then collapsing grid cells have + # to be detected within this function. + + # Bounds cannot be computed if there are duplicated cells + if self.contains_duplicated_cells: + raise Exception( + "This grid contains duplicated cell centers. 
Therefore bounds cannot be computed." + ) + + # Bounds are only possible for xarray.Datasets + if not isinstance(self.ds, xr.Dataset): + raise InvalidParameterValue( + "Bounds can only be attached to xarray.Datasets, not to xarray.DataArrays." + ) + if ( + np.amin(self.ds[self.lat].values) < -90.0 + or np.amax(self.ds[self.lat].values) > 90.0 + ): + warnings.warn( + "At least one latitude value exceeds [-90,90]. The bounds could not be calculated." + ) + return + if self.ncells < 3: + warnings.warn( + "The latitude and longitude axes need at least 3 entries" + " to be able to calculate the bounds." + ) + return + + # Use clisops.utils.dataset_utils functions to generate the bounds + if self.format == "CF": + if self.type == "curvilinear": + self.ds = clidu.generate_bounds_curvilinear( + ds=self.ds, lat=self.lat, lon=self.lon + ) + elif self.type == "regular_lat_lon": + self.ds = clidu.generate_bounds_rectilinear( + ds=self.ds, lat=self.lat, lon=self.lon + ) + else: + warnings.warn( + "The bounds cannot be calculated for grid_type '%s' and format '%s'." + % (self.type, self.format) + ) + return + + # Add common set of attributes and set as coordinates + self.ds = self.ds.set_coords(["lat_bnds", "lon_bnds"]) + self.ds = clidu.add_hor_CF_coord_attrs( + ds=self.ds, lat=self.lat, lon=self.lon + ) + + # Set the Class attributes + self.lat_bnds = "lat_bnds" + self.lon_bnds = "lon_bnds" + + # Issue warning + warnings.warn( + "Successfully calculated a set of latitude and longitude bounds." + " They might, however, differ from the actual bounds of the model grid." + ) + else: + warnings.warn( + "The bounds cannot be calculated for grid_type '%s' and format '%s'." + % (self.type, self.format) + ) + return + + def to_netcdf( + self, + folder: str | Path | None = "./", + filename: str | None = "", + grid_format: str | None = "CF", + keep_attrs: bool | None = True, + ): + """Store a copy of the horizontal Grid as netCDF file on disk. + + Define output folder, filename and output format (currently only 'CF' is supported). + Does not overwrite an existing file. + + Parameters + ---------- + folder : str or Path, optional + Output folder. The default is the current working directory "./". + filename : str, optional + Output filename, to be defined separately from folder. The default is 'grid_<hash>.nc'. + grid_format : str, optional + The format the grid information shall be stored as (in terms of variable attributes and dimensions). + The default is "CF", which is currently the only supported output format. + keep_attrs : bool, optional + Whether to store the global attributes in the output netCDF file. The default is True. + """ + # Check inputs + if filename: + if "/" in str(filename): + raise Exception( + "Target directory and filename have to be passed separately."
+ ) + filename = Path(folder, filename).as_posix() + else: + filename = Path(folder, "grid_" + self.hash + ".nc").as_posix() + + # Write to disk (just horizontal coordinate variables + global attrs) + # if not written by another process + if not os.path.isfile(filename): + LOCK = filename + ".lock" + lock_obj = FileLock(LOCK) + try: + lock_obj.acquire(timeout=10) + locked = False + except Exception as exc: + if str(exc) == f"Could not obtain file lock on {LOCK}": + locked = True + else: + locked = False + if locked: + warnings.warn( + f"Could not write grid '{filename}' to cache because a lockfile of " + "another process exists." + ) + else: + try: + # Create a copy of the Grid object with just the horizontal grid information + grid_tmp = Grid(ds=self.ds) + if grid_tmp.format != grid_format: + grid_tmp.grid_reformat(grid_format) + grid_tmp._drop_vars(keep_attrs=keep_attrs) + grid_tmp.ds.attrs.update({"clisops": __clisops_version__}) + + # Workaround for the following "features" of xarray: + # 1 # "When an xarray Dataset contains non-dimensional coordinates that do not + # share dimensions with any of the variables, these coordinate variable + # names are saved under a “global” "coordinates" attribute. This is not + # CF-compliant but again facilitates roundtripping of xarray datasets." + # 2 # "By default, variables with float types are attributed a _FillValue of NaN + # in the output file, unless explicitly disabled with an encoding + # {'_FillValue': None}." + if grid_tmp.lat_bnds and grid_tmp.lon_bnds: + grid_tmp.ds = grid_tmp.ds.reset_coords( + [grid_tmp.lat_bnds, grid_tmp.lon_bnds] + ) + grid_tmp.ds[grid_tmp.lat_bnds].encoding["_FillValue"] = None + grid_tmp.ds[grid_tmp.lon_bnds].encoding["_FillValue"] = None + + # Call to_netcdf method of xarray.Dataset + grid_tmp.ds.to_netcdf(filename) + finally: + lock_obj.release() + else: + # Issue a warning if the file already exists + # Not raising an exception since this method is also used to save + # grid files to the local cache + warnings.warn(f"The file '{Path(folder, filename)}' already exists.") + + +class Weights: + """Creates remapping weights out of two Grid objects serving as source and target grid. + + Reads weights from cache if possible. Reads weights from disk if specified (not yet implemented). + In the latter case, the weight file format has to be supported, to reformat it to xESMF format. + + Parameters + ---------- + grid_in : Grid + Grid object serving as source grid. + grid_out : Grid + Grid object serving as target grid. + method : str + Remapping method the weights should be / have been calculated with. One of ["nearest_s2d", + "bilinear", "conservative", "patch"] if weights have to be calculated. Free text if weights + are read from disk. + from_disk : str, optional + Not yet implemented. Instead of calculating the regridding weights (or reading them from + the cache), read them from disk. The default is None. + format : str, optional + Not yet implemented. When reading weights from disk, the input format may be specified. + If omitted, there will be an attempt to detect the format. The default is None. + """ + + @require_xesmf + def __init__( + self, + grid_in: Grid, + grid_out: Grid, + method: str, + from_disk: str | Path | None = None, + format: str | None = None, + ): + """Initialize Weights object, incl. calculating / reading the weights.""" + if not isinstance(grid_in, Grid) or not isinstance(grid_out, Grid): + raise InvalidParameterValue( + "Input and output grids have to be instances of clisops.core.Grid."
) + self.grid_in = grid_in + self.grid_out = grid_out + self.method = method + + # Compare source and target grid + if grid_in.hash == grid_out.hash: + raise Exception( + "The selected source and target grids are the same. " + "No regridding operation required." + ) + + # Periodic in longitude + # TODO: properly test / check the periodic attribute of the xESMF Regridder. + # The grid.extent check done here might not be suitable to set the periodic attribute: + # global == is grid periodic in longitude + self.periodic = False + try: + if self.grid_in.extent == "global": + self.periodic = True + except AttributeError: + # forced to False for conservative regridding in xesmf + # TODO: check if this is proper behaviour of xesmf + if self.method not in ["conservative", "conservative_normed"]: + warnings.warn( + "The grid extent could not be accessed. " + "It will be assumed that the input grid is not periodic in longitude." + ) + + # Activate ignore degenerate cells setting if collapsing cells are found within the grids. + # The default setting within ESMF is None, not False! + self.ignore_degenerate = ( + True + if ( + self.grid_in.contains_collapsed_cells + or self.grid_out.contains_collapsed_cells + ) + else None + ) + + self.id = self._generate_id() + self.filename = "weights_" + self.id + ".nc" + + if not from_disk: + # Read weights from cache or compute & save to cache + self.format = "xESMF" + self._compute() + else: + # Read weights from disk + self._load_from_disk(filename=from_disk, format=format) + + # Reformat and cache the weights if required + if not self.tool.startswith("xESMF"): + raise NotImplementedError( + f"Reading and reformatting weight files generated by {self.tool} is not supported. " + "The only weight file format currently supported is xESMF." + ) + self.format = self._detect_format() + self._reformat("xESMF") + + def _compute(self): + """Generate the weights with xESMF or read them from cache.""" + # Read weights_dir from CONFIG + weights_dir = CONFIG["clisops:grid_weights"]["local_weights_dir"] + + # Check if bounds are present in case of conservative remapping + if self.method in ["conservative", "conservative_normed"] and ( + not self.grid_in.lat_bnds + or not self.grid_in.lon_bnds + or not self.grid_out.lat_bnds + or not self.grid_out.lon_bnds + ): + raise Exception( + "For conservative remapping, horizontal grid bounds have to be defined for the source and target grids." + ) + + # Use "Locstream" functionality of xESMF as workaround for unstructured grids. + # Yet, the locstream functionality only supports the nearest neighbour remapping method + locstream_in = False + locstream_out = False + if self.grid_in.type == "unstructured": + locstream_in = True + if self.grid_out.type == "unstructured": + locstream_out = True + if any([locstream_in, locstream_out]) and self.method != "nearest_s2d": + raise NotImplementedError( + "For unstructured grids, the only remapping method currently supported " + "is nearest neighbour." + ) + + # Read weights from cache (= reuse weights) if they are not currently written + # to the cache by another process + # Note: xESMF writes weights to disk if filename is specified and reuse_weights==False + # (latter is default) else it will create a default filename and weights can + # be manually written to disk with Regridder.to_netcdf(filename). + # Weights are read from disk by xESMF if filename is specified and reuse_weights==True.
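+ # For reference, a minimal sketch of the reuse pattern just described (the file name is a placeholder): + # xe.Regridder(ds_in, ds_out, "bilinear", filename="weights_foo.nc", reuse_weights=os.path.isfile("weights_foo.nc"))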
+ lock_obj = create_lock(Path(weights_dir, self.filename + ".lock").as_posix()) + if not lock_obj: + warnings.warn( + f"Could not reuse cached weights '{self.filename}' because a " + "lockfile of another process exists that is writing to that file." + ) + reuse_weights = False + regridder_filename = None + else: + regridder_filename = Path(weights_dir, self.filename).as_posix() + if os.path.isfile(regridder_filename): + reuse_weights = True + else: + reuse_weights = False + + try: + # Read the tool & version the weights have been computed with - backup: current version + self.tool = self._read_info_from_cache("tool") + if not self.tool: + self.tool = "xESMF_v" + xe.__version__ + + # Call xesmf.Regridder + self.regridder = xe.Regridder( + self.grid_in.ds, + self.grid_out.ds, + self.method, + periodic=self.periodic, + locstream_in=locstream_in, + locstream_out=locstream_out, + ignore_degenerate=self.ignore_degenerate, + unmapped_to_nan=True, + filename=regridder_filename, + reuse_weights=reuse_weights, + ) + + # Save Weights to cache + self._save_to_cache(lock_obj) + finally: + # Release file lock + if lock_obj: + lock_obj.release() + + # The default filename is important for later use, so it needs to be reset. + self.regridder.filename = self.regridder._get_default_filename() + + def _generate_id(self) -> str: + """Create a unique id for a Weights object. + + The id consists of + - hashes / checksums of source and target grid (namely lat, lon, lat_bnds, lon_bnds variables) + - info about periodicity in longitude + - info about collapsing cells + - remapping method + + Returns + ------- + str + The id as str. + """ + peri_dict = {True: "peri", False: "unperi"} + ignore_degenerate_dict = { + None: "no-degen", + True: "skip-degen", + False: "no-skip-degen", + } + wid = "_".join( + filter( + None, + [ + self.grid_in.hash, + self.grid_out.hash, + peri_dict[self.periodic], + ignore_degenerate_dict[self.ignore_degenerate], + self.method, + ], + ) + ) + return wid + + @check_weights_dir + def _save_to_cache(self, store_weights: FileLock | None | bool) -> None: + """Save Weights and source/target grids to cache (netCDF), including metadata (JSON).""" + # Read weights_dir from CONFIG + weights_dir = CONFIG["clisops:grid_weights"]["local_weights_dir"] + + # Compile metadata + grid_in_source = self.grid_in.ds.encoding.get("source", "") + grid_out_source = self.grid_out.ds.encoding.get("source", "") + grid_in_tracking_id = self.grid_in.ds.attrs.get("tracking_id", "") + grid_out_tracking_id = self.grid_out.ds.attrs.get("tracking_id", "") + weights_dic = { + "source_uid": self.grid_in.hash, + "target_uid": self.grid_out.hash, + "source_lat": self.grid_in.lat, + "source_lon": self.grid_in.lon, + "source_lat_bnds": self.grid_in.lat_bnds, + "source_lon_bnds": self.grid_in.lon_bnds, + "source_nlat": self.grid_in.nlat, + "source_nlon": self.grid_in.nlon, + "source_ncells": self.grid_in.ncells, + "source_type": self.grid_in.type, + "source_format": self.grid_in.format, + "source_extent": self.grid_in.extent, + "source_source": grid_in_source, + "source_tracking_id": grid_in_tracking_id, + "target_lat": self.grid_out.lat, + "target_lon": self.grid_out.lon, + "target_lat_bnds": self.grid_out.lat_bnds, + "target_lon_bnds": self.grid_out.lon_bnds, + "target_nlat": self.grid_out.nlat, + "target_nlon": self.grid_out.nlon, + "target_ncells": self.grid_out.ncells, + "target_type": self.grid_out.type, + "target_format": self.grid_out.format, + "target_extent": self.grid_out.extent, + "target_source": grid_out_source, + 
"target_tracking_id": grid_out_tracking_id, + "format": self.format, + "ignore_degenerate": str(self.ignore_degenerate), + "periodic": str(self.periodic), + "method": self.method, + "uid": self.id, + "filename": self.filename, + "def_filename": self.regridder._get_default_filename(), + "tool": self.tool, + } + + # Save Grid objects to cache + self.grid_in.to_netcdf(folder=weights_dir) + self.grid_out.to_netcdf(folder=weights_dir) + + # Save Weights object (netCDF) and metadata (JSON) to cache if desired + # (usually, if no lockfile exists) + if store_weights: + if not os.path.isfile(Path(weights_dir, self.filename).as_posix()): + self.regridder.to_netcdf(Path(weights_dir, self.filename).as_posix()) + if not os.path.isfile( + Path(weights_dir, Path(self.filename).stem + ".json").as_posix() + ): + with open( + Path(weights_dir, Path(self.filename).stem + ".json").as_posix(), + "w", + ) as weights_dic_path: + json.dump(weights_dic, weights_dic_path, sort_keys=True, indent=4) + + @check_weights_dir + def _read_info_from_cache(self, key: str) -> str | None: + """Read info 'key' from cached metadata of current weight-file. + + Returns the value for the given key, unless the key does not exist in the metadata or the + file cannot be read. In this case, None is returned. + + Parameters + ---------- + key : str + + Returns + ------- + str or None + Value for the given key, or None. + """ + # Read weights_dir from CONFIG + weights_dir = CONFIG["clisops:grid_weights"]["local_weights_dir"] + + # Return requested value if weight and metadata files are present, else return None + if os.path.isfile( + Path(weights_dir, self.filename).as_posix() + ) and os.path.isfile( + Path(weights_dir, Path(self.filename).stem + ".json").as_posix() + ): + with open( + Path(weights_dir, Path(self.filename).stem + ".json").as_posix() + ) as f: + weights_dic = json.load(f) + try: + return weights_dic[key] + except KeyError: + warnings.warn( + f"Requested info {key} does not exist in the metadata" + " of the cached weights." + ) + return + else: + return + + def save_to_disk(self, filename=None, wformat: str = "xESMF") -> None: + """Write weights to disk in a certain format. + + Warning + ------- + This method is not yet implemented. + """ + # TODO: if necessary, reformat weights, then save under specified path. + raise NotImplementedError() + + def _load_from_disk(self, filename=None, format=None) -> None: + """Read and process weights from disk. + + Warning + ------- + This method is not yet implemented. + """ + # TODO: Reformat to other weight-file formats when loading/saving from disk + # if format != "xESMF": + # read file, compare Grid and weight matrix dimensions, + # reformat to xESMF sparse matrix and initialize xesmf.Regridder, + # generate_id, set ignore_degenerate, periodic, method to unknown if cannot be determined + raise NotImplementedError() + + def reformat(self, format_from: str, format_to: str) -> None: + """Reformat remapping weights. + + Warning + ------- + This method is not yet implemented. + """ + raise NotImplementedError() + + def _detect_format(self, ds: xr.Dataset | xr.DataArray) -> None: + """Detect format of remapping weights (read from disk). + + Warning + ------- + This method is not yet implemented. 
+ """ + raise NotImplementedError() + + +@require_xesmf +def regrid( + grid_in: Grid, + grid_out: Grid, + weights: Weights, + adaptive_masking_threshold: float | None = 0.5, + keep_attrs: bool | str = True, +) -> xr.Dataset: + """Perform regridding operation including dealing with dataset and variable attributes. + + Parameters + ---------- + grid_in : Grid + Grid object of the source grid, e.g. created out of source xarray.Dataset. + grid_out : Grid + Grid object of the target grid. + weights : Weights + Weights object, as created by using grid_in and grid_out Grid objects as input. + adaptive_masking_threshold : float, optional + (AMT) A value within the [0., 1.] interval that defines the maximum `RATIO` of missing_values amongst the total + number of data values contributing to the calculation of the target grid cell value. For a fraction [0., AMT[ + of the contributing source data missing, the target grid cell will be set to missing_value, else, it will be + re-normalized by the factor `1./(1.-RATIO)`. Thus, if AMT is set to 1, all source grid cells that contribute to + a target grid cell must be missing in order for the target grid cell to be defined as missing itself. Values + greater than 1 or less than 0 will cause adaptive masking to be turned off. This adaptive masking technique + allows to reuse generated weights for differently masked data (e.g. land-sea masks or orographic masks that vary + with depth / height). The default is 0.5. + keep_attrs : bool or str + Sets the global attributes of the resulting dataset, apart from the ones set by this routine: + True: attributes of grid_in.ds will be in the resulting dataset. + False: no attributes but the ones newly set by this routine + "target": attributes of grid_out.ds will be in the resulting dataset. + The default is True. + + Returns + ------- + xarray.Dataset + The regridded data in form of an xarray.Dataset. + """ + if not isinstance(grid_out.ds, xr.Dataset): + raise InvalidParameterValue( + "The target Grid object 'grid_out' has to be built from an xarray.Dataset" + " and not an xarray.DataArray." + ) + + # Duplicated cells / Halo + if grid_in.contains_duplicated_cells: + warnings.warn( + "The grid of the selected dataset contains duplicated cells. " + "For the conservative remapping method, " + "duplicated grid cells contribute to the resulting value, " + "which is in most parts counter-acted by the applied re-normalization. " + "However, please be wary with the results and consider removing / masking " + "the duplicated cells before remapping." + ) + + # Create attrs + attrs_append = {} + if isinstance(grid_in.ds, xr.Dataset): + if "grid" in grid_in.ds.attrs: + attrs_append["grid_original"] = grid_in.ds.attrs["grid"] + if "grid_label" in grid_in.ds.attrs: + attrs_append["grid_label_original"] = grid_in.ds.attrs["grid_label"] + nom_res_o = grid_in.ds.attrs.pop("nominal_resolution", None) + if nom_res_o: + attrs_append["nominal_resolution_original"] = nom_res_o + # TODO: should nominal_resolution of the target grid be calculated if not specified in the attr? + nom_res_n = grid_out.ds.attrs.pop("nominal_resolution", None) + if nom_res_n: + attrs_append["nominal_resolution"] = nom_res_n + + # Remove all unnecessary coords, data_vars (and optionally attrs) from grid_out.ds + if keep_attrs == "target": + grid_out._drop_vars(keep_attrs=True) + else: + grid_out._drop_vars(keep_attrs=False) + + # TODO: It might in general be sufficient to always act as if the threshold was + # set correctly and let xesmf handle it. 
But then we might not allow it for + # the bilinear method, as the results do not look too great and I am still + # not sure/convinced adaptive_masking makes sense for this method. + + # Allow Dataset and DataArray as input, but always return a Dataset + if isinstance(grid_in.ds, xr.Dataset): + for data_var in grid_in.ds.data_vars: + if not all( + [ + dim in grid_in.ds[data_var].dims + for dim in grid_in.ds[grid_in.lat].dims + + grid_in.ds[grid_in.lon].dims + ] + ): + continue + if weights.regridder.method in [ + "conservative", + "conservative_normed", + "patch", + ]: + # Re-normalize at least contributions from duplicated cells, if adaptive masking is deactivated + if ( + adaptive_masking_threshold < 0 or adaptive_masking_threshold > 1 + ) and grid_in.contains_duplicated_cells: + adaptive_masking_threshold = 0.0 + grid_out.ds[data_var] = weights.regridder( + grid_in.ds[data_var], + skipna=True, + na_thres=adaptive_masking_threshold, + ) + else: + grid_out.ds[data_var] = weights.regridder( + grid_in.ds[data_var], skipna=False + ) + if keep_attrs: + grid_out.ds[data_var].attrs.update(grid_in.ds[data_var].attrs) + grid_out.ds[data_var].encoding.update(grid_in.ds[data_var].encoding) + + # Transfer all non-horizontal coords (and optionally attrs) from grid_in.ds to grid_out.ds + grid_out._transfer_coords(grid_in, keep_attrs=keep_attrs) + + else: + if ( + weights.regridder.method in ["conservative", "conservative_normed", "patch"] + and 0.0 <= adaptive_masking_threshold <= 1.0 + ): + grid_out.ds[grid_in.ds.name] = weights.regridder( + grid_in.ds, skipna=True, na_thres=adaptive_masking_threshold + ) + else: + grid_out.ds[grid_in.ds.name] = weights.regridder(grid_in.ds, skipna=False) + if keep_attrs: + grid_out.ds[grid_in.ds.name].attrs.update(grid_in.ds.attrs) + grid_out.ds[grid_in.ds.name].encoding.update(grid_in.ds.encoding) + + # Add new attrs + grid_out.ds.attrs.update(attrs_append) + grid_out.ds.attrs.update( + { + "grid": grid_out.title, + "grid_label": "gr", # regridded data reported on the data provider's preferred target grid + "regrid_operation": weights.regridder.filename.split(".")[0], + "regrid_tool": weights.tool, + "regrid_weights_uid": weights.id, + } + ) + return grid_out.ds diff --git a/clisops/etc/roocs.ini b/clisops/etc/roocs.ini index 3cb27db8..ad6d88f1 100644 --- a/clisops/etc/roocs.ini +++ b/clisops/etc/roocs.ini @@ -5,6 +5,14 @@ chunk_memory_limit = 250MiB file_size_limit = 1GB output_staging_dir = +[clisops:grid_weights] +local_weights_dir = /tmp/clisops_grid_weights +remote_weights_svc = + +[clisops:coordinate_precision] +hor_coord_decimals = 6 +vert_coord_decimals = 6 + [project:cordex] file_name_template = {__derive__var_id}_{CORDEX_domain}_{driving_model_id}_{experiment_id}_{driving_model_ensemble_member}_{model_id}_{rcm_version_id}_{frequency}{__derive__time_range}{extra}.{__derive__extension} attr_defaults = diff --git a/clisops/ops/__init__.py b/clisops/ops/__init__.py index 69a57825..7c243230 100644 --- a/clisops/ops/__init__.py +++ b/clisops/ops/__init__.py @@ -1,2 +1,3 @@ from .average import average +from .regrid import regrid from .subset import subset diff --git a/clisops/ops/base_operation.py b/clisops/ops/base_operation.py index 9ed93989..db62b6fb 100644 --- a/clisops/ops/base_operation.py +++ b/clisops/ops/base_operation.py @@ -1,3 +1,4 @@ +from collections import ChainMap from pathlib import Path from typing import List, Union @@ -65,28 +66,39 @@ def _calculate(self): def _remove_redundant_fill_values(self, ds): """ - Get coordinate variables and remove 
fill values added by xarray (CF conventions say that coordinate variables cannot have missing values). - Get bounds variables and remove fill values added by xarray. + Get coordinate and data variables and remove fill values added by xarray + (CF conventions say that coordinate variables cannot have missing values). See issue: https://github.com/roocs/clisops/issues/224 """ if isinstance(ds, xr.Dataset): - main_var = get_main_variable(ds) - for coord_id in ds[main_var].coords: - # remove fill value from coordinate variables - # if ds.coords[coord_id].dims == (coord_id,): - ds[coord_id].encoding["_FillValue"] = None - # remove fill value from bounds variables if they exist - try: - bnd = ds.cf.get_bounds(coord_id).name - ds[bnd].encoding["_FillValue"] = None - except KeyError: - continue + varlist = list(ds.coords) + list(ds.data_vars) + elif isinstance(ds, xr.DataArray): + varlist = list(ds.coords) + + for var in varlist: + fval = ChainMap(ds[var].attrs, ds[var].encoding).get("_FillValue", None) + mval = ChainMap(ds[var].attrs, ds[var].encoding).get("missing_value", None) + if not fval and not mval: + ds[var].encoding["_FillValue"] = None + elif not mval: + ds[var].encoding["missing_value"] = fval + ds[var].encoding["_FillValue"] = fval + ds[var].attrs.pop("_FillValue", None) + elif not fval: + ds[var].encoding["_FillValue"] = mval + ds[var].encoding["missing_value"] = mval + ds[var].attrs.pop("missing_value", None) + else: + if fval != mval: + raise Exception( + f"The defined _FillValue and missing_value for '{var}' are not the same '{fval}' != '{mval}'." + ) return ds - def _remove_redundant_coordinates_from_bounds(self, ds): + def _remove_redundant_coordinates_attr(self, ds): """ - This method removes redundant coordinates from bounds, example: + This method removes the coordinates attribute added by xarray, example: double time_bnds(time, bnds) ; time_bnds:coordinates = "height" ; @@ -98,13 +110,17 @@ def _remove_redundant_coordinates_from_bounds(self, ds): See issue: https://github.com/roocs/clisops/issues/224 """ if isinstance(ds, xr.Dataset): - main_var = get_main_variable(ds) - for coord_id in ds[main_var].coords: - try: - bnd = ds.cf.get_bounds(coord_id).name - ds[bnd].encoding["coordinates"] = None - except KeyError: - continue + varlist = list(ds.coords) + list(ds.data_vars) + elif isinstance(ds, xr.DataArray): + varlist = list(ds.coords) + + for var in varlist: + cattr = ChainMap(ds[var].attrs, ds[var].encoding).get("coordinates", None) + if not cattr: + ds[var].encoding["coordinates"] = None + else: + ds[var].encoding["coordinates"] = cattr + ds[var].attrs.pop("coordinates", None) return ds def process(self) -> List[Union[xr.Dataset, Path]]: @@ -129,7 +145,7 @@ def process(self) -> List[Union[xr.Dataset, Path]]: # remove fill values from lat/lon/time if required processed_ds = self._remove_redundant_fill_values(processed_ds) # remove redundant coordinates from bounds - processed_ds = self._remove_redundant_coordinates_from_bounds(processed_ds) + processed_ds = self._remove_redundant_coordinates_attr(processed_ds) # Work out how many outputs should be created based on the size # of the array. Manage this as a list of time slices. 
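A minimal sketch of how the `_FillValue` / `missing_value` harmonization in `_remove_redundant_fill_values` above behaves (illustrative only, not part of the patch; the variable below is hypothetical): whichever of the two values is defined is written to both encoding entries, so a single consistent value reaches the output file.

```python
from collections import ChainMap

import numpy as np
import xarray as xr

# Hypothetical variable that only defines "missing_value" in its attrs.
da = xr.DataArray(np.arange(4.0), name="tas", attrs={"missing_value": 1.0e20})

fval = ChainMap(da.attrs, da.encoding).get("_FillValue", None)
mval = ChainMap(da.attrs, da.encoding).get("missing_value", None)

# fval is None and mval == 1e20, i.e. the "elif not fval" branch applies:
# both encoding entries are set to the defined value and the now redundant
# attribute is removed before writing to disk.
da.encoding["_FillValue"] = mval
da.encoding["missing_value"] = mval
da.attrs.pop("missing_value", None)

assert da.encoding["_FillValue"] == da.encoding["missing_value"] == 1.0e20
```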
diff --git a/clisops/ops/regrid.py b/clisops/ops/regrid.py new file mode 100644 index 00000000..9e7ba1bf --- /dev/null +++ b/clisops/ops/regrid.py @@ -0,0 +1,216 @@ +from datetime import datetime as dt +from pathlib import Path +from typing import List, Optional, Union + +import xarray as xr +from loguru import logger +from roocs_utils.exceptions import InvalidParameterValue + +from clisops.core import Grid, Weights +from clisops.core import regrid as core_regrid +from clisops.ops.base_operation import Operation +from clisops.utils.file_namers import get_file_namer + +# from clisops.utils.output_utils import get_output, get_time_slices + +__all__ = [ + "regrid", +] + +supported_regridding_methods = ["conservative", "patch", "nearest_s2d", "bilinear"] + + +class Regrid(Operation): + """Class for regridding operation, extends clisops.ops.base_operation.Operation.""" + + def _get_grid_in( + self, + grid_desc: Union[xr.Dataset, xr.DataArray], + compute_bounds: bool, + ): + """ + Create clisops.core.regrid.Grid object as input grid of the regridding operation. + + Return the Grid object. + """ + if isinstance(grid_desc, (xr.Dataset, xr.DataArray)): + return Grid(ds=grid_desc, compute_bounds=compute_bounds) + raise InvalidParameterValue( + "An xarray.Dataset or xarray.DataArray has to be provided as input for the source grid." + ) + + def _get_grid_out( + self, + grid_desc: Union[xr.Dataset, xr.DataArray, int, float, tuple, str], + compute_bounds: bool, + ): + """ + Create clisops.core.regrid.Grid object as target grid of the regridding operation. + + Return the Grid object. + """ + if isinstance(grid_desc, str): + if grid_desc in ["auto", "adaptive"]: + return Grid( + ds=self.ds, grid_id=grid_desc, compute_bounds=compute_bounds + ) + else: + return Grid(grid_id=grid_desc, compute_bounds=compute_bounds) + elif isinstance(grid_desc, (float, int, tuple)): + return Grid(grid_instructor=grid_desc, compute_bounds=compute_bounds) + elif isinstance(grid_desc, (xr.Dataset, xr.DataArray)): + return Grid(ds=grid_desc, compute_bounds=compute_bounds) + else: + # clisops.core.regrid.Grid will raise the exception + return Grid() + + def _get_weights(self, grid_in: Grid, grid_out: Grid, method: str): + """ + Generate the remapping weights using clisops.core.regrid.Weights. + + Return the Weights object. + """ + return Weights(grid_in=grid_in, grid_out=grid_out, method=method) + + def _resolve_params(self, **params): + """Generate a dictionary of regrid parameters.""" + # All regrid-specific parameters should be passed in via **params. + # This is where we resolve them and set self.params as a dict or as separate attributes, + # making use of the other methods / attributes, e.g. + # _get_grid_in(), _get_grid_out() and _get_weights(), to generate the regridder. + + adaptive_masking_threshold = params.get("adaptive_masking_threshold", None) + grid = params.get("grid", None) + method = params.get("method", None) + keep_attrs = params.get("keep_attrs", None) + + if method not in supported_regridding_methods: + raise Exception( + "The selected regridding method is not supported. " + "Please choose one of %s."
% ", ".join(supported_regridding_methods) + ) + + logger.debug( + f"Input parameters: method: {method}, grid: {grid}, adaptive_masking: {adaptive_masking_threshold}" + ) + + # Compute bounds only when required + compute_bounds = "conservative" in method + + # Create and check source and target grids + grid_in = self._get_grid_in(self.ds, compute_bounds) + grid_out = self._get_grid_out(grid, compute_bounds) + + # Compute the remapping weights + t_start = dt.now() + weights = self._get_weights(grid_in=grid_in, grid_out=grid_out, method=method) + t_end = dt.now() + logger.info( + f"Computed/Retrieved weights in {(t_end-t_start).total_seconds()} seconds." + ) + + # Define params dict + self.params = { + "grid_in": grid_in, + "grid_out": grid_out, + "method": method, + "regridder": weights.regridder, + "weights": weights, + "adaptive_masking_threshold": adaptive_masking_threshold, + "keep_attrs": keep_attrs, + } + + # Input grid / Dataset + self.ds = self.params.get("grid_in").ds + + # Theres no __str__() method for the Regridder object, so I used its filename attribute, + # which specifies a default filename (which has but not much to do with the filename we would give the weight file). + # todo: Better option might be to have the Weights class extend the Regridder class or to define + # a __str__() method for the Weights class. + logger.debug( + "Resolved parameters: grid_in: {}, grid_out: {}, regridder: {}".format( + self.params.get("grid_in").__str__(), + self.params.get("grid_out").__str__(), + self.params.get("regridder").filename, + ) + ) + + def _get_file_namer(self): + """Return the appropriate file namer object.""" + # "extra" is what will go at the end of the file name before .nc + extra = "_regrid-{}-{}".format( + self.params.get("method"), self.params.get("grid_out").__str__() + ) + + namer = get_file_namer(self._file_namer)(extra=extra) + + return namer + + def _calculate(self): + """ + Process the regridding request, calls clisops.core.regrid.regrid(). + + Returns the resulting xarray.Dataset. + """ + # the result is saved by the process() method on the base class + regridded_ds = core_regrid( + self.params.get("grid_in", None), + self.params.get("grid_out", None), + self.params.get("weights", None), + self.params.get("adaptive_masking_threshold", None), + self.params.get("keep_attrs", None), + ) + + return regridded_ds + + +def regrid( + ds: Union[xr.Dataset, str, Path], + *, + method: Optional[str] = "nearest_s2d", + adaptive_masking_threshold: Optional[Union[int, float]] = 0.5, + grid: Optional[ + Union[xr.Dataset, xr.DataArray, int, float, tuple, str] + ] = "adaptive", + output_dir: Optional[Union[str, Path]] = None, + output_type: Optional[str] = "netcdf", + split_method: Optional[str] = "time:auto", + file_namer: Optional[str] = "standard", + keep_attrs: Optional[Union[bool, str]] = True, +) -> List[Union[xr.Dataset, str]]: + """ + Regrid specified input file or xarray object. + + Parameters + ---------- + ds: Union[xr.Dataset, str] + method="nearest_s2d", + adaptive_masking_threshold=0.5, + grid="adaptive", + output_dir: Optional[Union[str, Path]] = None + output_type: {"netcdf", "nc", "zarr", "xarray"} + split_method: {"time:auto"} + file_namer: {"standard", "simple"} + keep_attrs: {True, False, "target"} + + Returns + ------- + List[Union[xr.Dataset, str]] + A list of the regridded outputs in the format selected; str corresponds to file paths if the + output format selected is a file. 
+ + Examples + -------- + | ds: xarray Dataset or "cmip5.output1.MOHC.HadGEM2-ES.rcp85.mon.atmos.Amon.r1i1p1.latest.tas" + | method: "nearest_s2d" + | adaptive_masking_threshold: + | grid: "1deg" + | output_dir: "/cache/wps/procs/req0111" + | output_type: "netcdf" + | split_method: "time:auto" + | file_namer: "standard" + | keep_attrs: True + + """ + op = Regrid(**locals()) + return op.process() diff --git a/clisops/ops/subset.py b/clisops/ops/subset.py index 98861c6e..b46e537b 100644 --- a/clisops/ops/subset.py +++ b/clisops/ops/subset.py @@ -215,7 +215,7 @@ def subset( -------- | ds: xarray Dataset or "cmip5.output1.MOHC.HadGEM2-ES.rcp85.mon.atmos.Amon.r1i1p1.latest.tas" | time: ("1999-01-01T00:00:00", "2100-12-30T00:00:00") or "2085-01-01T12:00:00Z/2120-12-30T12:00:00Z" - | area: (-5.,49.,10.,65) or "0.,49.,10.,65" or [0, 49.5, 10, 65] + | area: (-5.,49.,10.,65) or "0.,49.,10.,65" or [0, 49.5, 10, 65] with the order being lon_0, lat_0, lon_1, lat_1 | level: (1000.,) or "1000/2000" or ("1000.50", "2000.60") | time_components: "year:2000,2004,2008|month:01,02" or {"year": (2000, 2004, 2008), "months": (1, 2)} | output_dir: "/cache/wps/procs/req0111" diff --git a/clisops/utils/common.py b/clisops/utils/common.py index 109cc211..33609c1d 100644 --- a/clisops/utils/common.py +++ b/clisops/utils/common.py @@ -1,9 +1,14 @@ +import functools +import os import sys from pathlib import Path -from typing import List, Union +from types import FunctionType, ModuleType +from typing import List, Optional, Union from loguru import logger +# from roocs_utils.parameter import parameterise + def expand_wildcards(paths: Union[str, Path]) -> list: """Expand the wildcards that may be present in Paths.""" @@ -12,6 +17,37 @@ def expand_wildcards(paths: Union[str, Path]) -> list: return [f for f in Path(path.root).glob(str(Path("").joinpath(*parts)))] + +def require_module( + func: FunctionType, + module: ModuleType, + module_name: str, + min_version: Optional[str] = "0.0.0", +): + """Decorator ensuring that the given module is installed before the function/method is called.""" + + @functools.wraps(func) + def wrapper_func(*args, **kwargs): + if module is None: + raise Exception( + f"Package {module_name} >= {min_version} is required to use {func}." + ) + return func(*args, **kwargs) + + return wrapper_func + + +def check_dir(func: FunctionType, dr: Union[str, Path]): + """Decorator ensuring that the directory 'dr' exists before the function/method is called.""" + if not os.path.isdir(dr): + os.makedirs(dr) + + @functools.wraps(func) + def wrapper_func(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper_func + + def _logging_examples() -> None: """Testing module.""" logger.trace("0") @@ -43,3 +79,27 @@ def enable_logging() -> List[int]: ] ) return logger.configure(**config) + + +def _list_ten(list1d): + """Convert a list to a string showing at most 10 of its elements, split evenly between the beginning and the end of the list. + + Parameters + ---------- + list1d : list + 1D list. + + Returns + ------- + str + String containing the comma-separated first and last 5 elements of the list, with "..." in between. + For example "1, 2, 3, 4, 5 ... 21, 22, 23, 24, 25". + """ + if len(list1d) < 11: + return ", ".join(str(i) for i in list1d) + else: + return ( + ", ".join(str(i) for i in list1d[0:5]) + + " ... 
" + + ", ".join(str(i) for i in list1d[-5:]) + ) diff --git a/clisops/utils/dataset_utils.py b/clisops/utils/dataset_utils.py index 5cf6657f..29b1ec45 100644 --- a/clisops/utils/dataset_utils.py +++ b/clisops/utils/dataset_utils.py @@ -1,7 +1,7 @@ import warnings -from typing import Optional +from typing import Optional, Tuple -import cf_xarray # noqa +import cf_xarray as cfxr # noqa import cftime import numpy as np import xarray as xr @@ -368,6 +368,550 @@ def adjust_date_to_calendar(da, date, direction="backwards"): ) +def add_hor_CF_coord_attrs( + ds, lat="lat", lon="lon", lat_bnds="lat_bnds", lon_bnds="lon_bnds", keep_attrs=False +): + """ + Add the common CF variable attributes to the horizontal coordinate variables. + + Parameters + ---------- + lat : str, optional + Latitude coordinate variable name. The default is "lat". + lon : str, optional + Longitude coordinate variable name. The default is "lon". + lat_bnds : str, optional + Latitude bounds coordinate variable name. The default is "lat_bnds". + lon_bnds : str, optional + Longitude bounds coordinate variable name. The default is "lon_bnds". + keep_attrs : bool, optional + Whether to keep original coordinate variable attributes if they do not conflict. + In case of a conflict, the attribute value will be overwritten independent of this setting. + The default is False. + + Returns + ------- + xarray.Dataset + The input dataset with updated coordinate variable attributes. + """ + # Define common CF coordinate variable attrs + lat_attrs = { + "bounds": "lat_bnds", + "units": "degrees_north", + "long_name": "latitude", + "standard_name": "latitude", + "axis": "Y", + } + lon_attrs = { + "bounds": "lon_bnds", + "units": "degrees_east", + "long_name": "longitude", + "standard_name": "longitude", + "axis": "X", + } + + # Overwrite or update coordinate variables of input dataset + try: + if keep_attrs: + ds[lat].attrs.update(lat_attrs) + ds[lon].attrs.update(lon_attrs) + else: + ds[lat].attrs = lat_attrs + ds[lon].attrs = lon_attrs + except KeyError: + raise KeyError("Not all specified coordinate variables exist in the dataset.") + + return ds + + +def reformat_SCRIP_to_CF(ds, keep_attrs=False): + """Reformat dataset from SCRIP to CF format. + + Parameters + ---------- + ds : xarray.Dataset + Input dataset in SCRIP format. + keep_attrs: bool + Whether to keep the global attributes. + + Returns + ------- + ds_ref : xarray.Dataset + Reformatted dataset. + """ + source_format = "SCRIP" + target_format = "CF" + SCRIP_vars = [ + "grid_center_lat", + "grid_center_lon", + "grid_corner_lat", + "grid_corner_lon", + "grid_dims", + "grid_area", + "grid_imask", + ] + + if not isinstance(ds, xr.Dataset): + raise InvalidParameterValue( + "Reformat is only possible for Datasets." + " DataArrays have to be CF conformal coordinate variables defined." + ) + + # Cannot reformat data variables yet + if not ( + all([var in SCRIP_vars for var in ds.data_vars]) + and all([coord in SCRIP_vars for coord in ds.coords]) + ): + raise Exception( + "Converting the grid format from %s to %s is not yet possible for data variables." 
+ % (source_format, target_format) + ) + + # center lat and lon arrays will become the lat and lon arrays + lat = ds.grid_center_lat.values.reshape( + (ds.grid_dims.values[1], ds.grid_dims.values[0]) + ).astype(np.float32) + lon = ds.grid_center_lon.values.reshape( + (ds.grid_dims.values[1], ds.grid_dims.values[0]) + ).astype(np.float32) + + # corner coordinates will become lat_bnds and lon_bnds arrays + # regular lat-lon case + # todo: bounds of curvilinear case + if all( + [ + np.array_equal(lat[:, i], lat[:, i + 1], equal_nan=True) + for i in range(ds.grid_dims.values[0] - 1) + ] + ) and all( + [ + np.array_equal(lon[i, :], lon[i + 1, :], equal_nan=True) + for i in range(ds.grid_dims.values[1] - 1) + ] + ): + # regular lat-lon grid: + # - 1D coordinate variables + lat = lat[:, 0] + lon = lon[0, :] + # - reshape vertices from (n,2) to (n+1) for lat and lon axes + lat_b = ds.grid_corner_lat.values.reshape( + ( + ds.grid_dims.values[1], + ds.grid_dims.values[0], + ds.dims["grid_corners"], + ) + ).astype(np.float32) + lon_b = ds.grid_corner_lon.values.reshape( + ( + ds.grid_dims.values[1], + ds.grid_dims.values[0], + ds.dims["grid_corners"], + ) + ).astype(np.float32) + lat_bnds = np.zeros((ds.grid_dims.values[1], 2), dtype=np.float32) + lon_bnds = np.zeros((ds.grid_dims.values[0], 2), dtype=np.float32) + lat_bnds[:, 0] = np.min(lat_b[:, 0, :], axis=1) + lat_bnds[:, 1] = np.max(lat_b[:, 0, :], axis=1) + lon_bnds[:, 0] = np.min(lon_b[0, :, :], axis=1) + lon_bnds[:, 1] = np.max(lon_b[0, :, :], axis=1) + ds_ref = xr.Dataset( + data_vars={}, + coords={ + "lat": (["lat"], lat), + "lon": (["lon"], lon), + "lat_bnds": (["lat", "bnds"], lat_bnds), + "lon_bnds": (["lon", "bnds"], lon_bnds), + }, + ) + # todo: Case of other units (rad) + # todo: Reformat data variables if in ds, apply imask on data variables + # todo: vertical axis, time axis, ... ?! + + # add common coordinate variable attrs + ds_ref = add_hor_CF_coord_attrs(ds=ds_ref) + + # transfer global attributes + if keep_attrs: + ds_ref.attrs.update(ds.attrs) + + return ds_ref + else: + raise Exception( + "Converting the grid format from %s to %s is currently only possible for regular latitude-longitude grids." + % (source_format, target_format) + ) + + +def reformat_xESMF_to_CF(ds, keep_attrs=False): + """Reformat dataset from xESMF to CF format. + + Parameters + ---------- + ds : xarray.Dataset + Input dataset in xESMF format. + keep_attrs : bool + Whether to keep the global attributes. + + Returns + ------- + ds_ref : xarray.Dataset + Reformatted dataset.
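+ + Examples + -------- + A minimal, illustrative round trip (assuming xESMF is installed), using the xESMF utility grid generator as input: + + >>> import xesmf as xe + >>> ds_xesmf = xe.util.grid_global(5.0, 4.0) + >>> ds_cf = reformat_xESMF_to_CF(ds_xesmf) + >>> sorted(ds_cf.coords) + ['lat', 'lat_bnds', 'lon', 'lon_bnds']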
+ """ + # source_format="xESMF" + # target_format="CF" + # todo: Check if it is regular_lat_lon, Check dimension sizes + # Define lat, lon, lat_bnds, lon_bnds + lat = ds.lat[:, 0] + lon = ds.lon[0, :] + lat_bnds = np.zeros((lat.shape[0], 2), dtype=np.float32) + lon_bnds = np.zeros((lon.shape[0], 2), dtype=np.float32) + + # From (N+1, M+1) shaped bounds to (N, M, 4) shaped vertices + lat_vertices = cfxr.vertices_to_bounds(ds.lat_b, ("bnds", "lat", "lon")).values + lon_vertices = cfxr.vertices_to_bounds(ds.lon_b, ("bnds", "lat", "lon")).values + + # No longer necessary as of cf_xarray v0.7.5 + # lat_vertices = np.moveaxis(lat_vertices, 0, -1) + # lon_vertices = np.moveaxis(lon_vertices, 0, -1) + + # From (N, M, 4) shaped vertices to (N, 2) and (M, 2) shaped bounds + lat_bnds[:, 0] = np.min(lat_vertices[:, 0, :], axis=1) + lat_bnds[:, 1] = np.max(lat_vertices[:, 0, :], axis=1) + lon_bnds[:, 0] = np.min(lon_vertices[0, :, :], axis=1) + lon_bnds[:, 1] = np.max(lon_vertices[0, :, :], axis=1) + + # Create dataset + ds_ref = xr.Dataset( + data_vars={}, + coords={ + "lat": (["lat"], lat.data), + "lon": (["lon"], lon.data), + "lat_bnds": (["lat", "bnds"], lat_bnds.data), + "lon_bnds": (["lon", "bnds"], lon_bnds.data), + }, + ) + + # todo: Case of other units (rad) + # todo: Reformat data variables if in ds, apply imask on data variables + # todo: vertical axis, time axis, ... ?! + + # add common coordinate variable attrs + ds_ref = add_hor_CF_coord_attrs(ds=ds_ref) + + # transfer global attributes + if keep_attrs: + ds_ref.attrs.update(ds.attrs) + + return ds_ref + # else: + # raise Exception( + # "Converting the grid format from %s to %s is yet only possible for regular latitude longitude grids." + # % (self.format, grid_format) + # ) + + +def detect_format(ds): + """Detect format of a dataset. Yet supported are 'CF', 'SCRIP', 'xESMF'. + + Parameters + ---------- + ds : xr.Dataset + xarray.Dataset of which to detect the format. + + Returns + ------- + str + The format, if supported. Else raises an Exception. + """ + # todo: extend for formats CF, xESMF, ESMF, UGRID, SCRIP + # todo: add more conditions (dimension sizes, ...) + SCRIP_vars = [ + "grid_center_lat", + "grid_center_lon", + "grid_corner_lat", + "grid_corner_lon", + # "grid_imask", "grid_area" + ] + SCRIP_dims = ["grid_corners", "grid_size", "grid_rank"] + + xESMF_vars = [ + "lat", + "lon", + "lat_b", + "lon_b", + # "mask", + ] + xESMF_dims = ["x", "y", "x_b", "y_b"] + + # Test if SCRIP + if all([var in ds for var in SCRIP_vars]) and all( + [dim in ds.dims for dim in SCRIP_dims] + ): + return "SCRIP" + + # Test if xESMF + elif all([var in ds.coords for var in xESMF_vars]) and all( + [dim in ds.dims for dim in xESMF_dims] + ): + return "xESMF" + + # Test if latitude and longitude can be found - standard_names would be set later if undef. + elif ( + "latitude" in ds.cf.standard_names and "longitude" in ds.cf.standard_names + ) or ( + get_coord_by_type(ds, "latitude", ignore_aux_coords=False) is not None + and get_coord_by_type(ds, "longitude", ignore_aux_coords=False) is not None + ): + return "CF" + + else: + raise Exception("The grid format is not supported.") + + +def detect_shape(ds, lat, lon, grid_type) -> Tuple[int, int, int]: + """Detect the shape of the grid. + + Returns a tuple of (nlat, nlon, ncells). For an unstructured grid nlat and nlon are not defined + and therefore the returned tuple will be (ncells, ncells, ncells). + + Parameters + ---------- + ds : xr.Dataset + Dataset containing the grid / coordinate variables. 
+ lat : str + Latitude variable name. + lon : str + Longitude variable name. + grid_type : str + One of "regular_lat_lon", "curvilinear", "unstructured". + + Returns + ------- + int + Number of latitude points in the grid. + int + Number of longitude points in the grid. + int + Number of cells in the grid. + """ + if grid_type not in ["regular_lat_lon", "curvilinear", "unstructured"]: + raise Exception(f"The specified grid_type '{grid_type}' is not supported.") + + if ds[lat].ndim != ds[lon].ndim: + raise Exception( + f"The coordinate variables {lat} and {lon} do not have the same number of dimensions." + ) + elif ds[lat].ndim == 2: + nlat = ds[lat].shape[0] + nlon = ds[lon].shape[1] + ncells = nlat * nlon + elif ds[lat].ndim == 1: + if ds[lat].shape == ds[lon].shape and grid_type == "unstructured": + nlat = ds[lat].shape[0] + nlon = nlat + ncells = nlat + else: + nlat = ds[lat].shape[0] + nlon = ds[lon].shape[0] + ncells = nlat * nlon + else: + raise Exception( + f"The coordinate variables {lat} and {lon} are not 1- or 2-dimensional." + ) + return nlat, nlon, ncells + + +def generate_bounds_curvilinear(ds, lat, lon): + """Compute bounds for curvilinear grids. + + Assumes 2D latitude and longitude coordinate variables. The bounds will be attached as coords + to the xarray.Dataset of the Grid object. If no bounds can be created, a warning is issued. + It is assumed but not ensured that no duplicated cells are present in the grid. + + The bounds calculation for curvilinear grids was adapted from + https://github.com/SantanderMetGroup/ATLAS/blob/mai-devel/scripts/ATLAS-data/\ bash-interpolation-scripts/AtlasCDOremappeR_CORDEX/grid_bounds_calc.py + which is based on work by Caillaud Cécile and Samuel Somot from Meteo-France. + + Parameters + ---------- + ds : xarray.Dataset + Dataset to generate the bounds for. + lat : str + Latitude variable name. + lon : str + Longitude variable name. + + Returns + ------- + xarray.Dataset + Dataset with attached bounds.
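+ + Examples + -------- + A small, illustrative 2D grid (coordinate values are arbitrary): + + >>> import numpy as np + >>> import xarray as xr + >>> lon2d, lat2d = np.meshgrid(np.arange(0.0, 4.0), np.arange(40.0, 43.0)) + >>> ds = xr.Dataset(coords={"lat": (["y", "x"], lat2d), "lon": (["y", "x"], lon2d)}) + >>> ds = generate_bounds_curvilinear(ds, lat="lat", lon="lon") + >>> ds.lat_bnds.shape + (3, 4, 4)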
+ """ + # Detect shape + nlat, nlon, ncells = detect_shape(ds=ds, lat=lat, lon=lon, grid_type="curvilinear") + + # Rearrange lat/lons + lons_row = ds[lon].data.flatten() + lats_row = ds[lat].data.flatten() + + # Allocate lat/lon corners + lons_cor = np.zeros(lons_row.size * 4) + lats_cor = np.zeros(lats_row.size * 4) + + lons_crnr = np.empty((ds[lon].shape[0] + 1, ds[lon].shape[1] + 1)) + lons_crnr[:] = np.nan + lats_crnr = np.empty((ds[lat].shape[0] + 1, ds[lat].shape[1] + 1)) + lats_crnr[:] = np.nan + + # -------- Calculating corners --------- # + + # Loop through all grid points except at the boundaries + for ilat in range(1, ds[lon].shape[0]): + for ilon in range(1, ds[lon].shape[1]): + # SW corner for each lat/lon index is calculated + lons_crnr[ilat, ilon] = ( + ds[lon][ilat - 1, ilon - 1] + + ds[lon][ilat, ilon - 1] + + ds[lon][ilat - 1, ilon] + + ds[lon][ilat, ilon] + ) / 4.0 + lats_crnr[ilat, ilon] = ( + ds[lat][ilat - 1, ilon - 1] + + ds[lat][ilat, ilon - 1] + + ds[lat][ilat - 1, ilon] + + ds[lat][ilat, ilon] + ) / 4.0 + + # Grid points at boundaries + lons_crnr[0, :] = lons_crnr[1, :] - (lons_crnr[2, :] - lons_crnr[1, :]) + lons_crnr[-1, :] = lons_crnr[-2, :] + (lons_crnr[-2, :] - lons_crnr[-3, :]) + lons_crnr[:, 0] = lons_crnr[:, 1] + (lons_crnr[:, 1] - lons_crnr[:, 2]) + lons_crnr[:, -1] = lons_crnr[:, -2] + (lons_crnr[:, -2] - lons_crnr[:, -3]) + + lats_crnr[0, :] = lats_crnr[1, :] - (lats_crnr[2, :] - lats_crnr[1, :]) + lats_crnr[-1, :] = lats_crnr[-2, :] + (lats_crnr[-2, :] - lats_crnr[-3, :]) + lats_crnr[:, 0] = lats_crnr[:, 1] - (lats_crnr[:, 1] - lats_crnr[:, 2]) + lats_crnr[:, -1] = lats_crnr[:, -2] + (lats_crnr[:, -2] - lats_crnr[:, -3]) + + # ------------ DONE ------------- # + + # Fill in counterclockwise and rearrange + count = 0 + for ilat in range(ds[lon].shape[0]): + for ilon in range(ds[lon].shape[1]): + lons_cor[count] = lons_crnr[ilat, ilon] + lons_cor[count + 1] = lons_crnr[ilat, ilon + 1] + lons_cor[count + 2] = lons_crnr[ilat + 1, ilon + 1] + lons_cor[count + 3] = lons_crnr[ilat + 1, ilon] + + lats_cor[count] = lats_crnr[ilat, ilon] + lats_cor[count + 1] = lats_crnr[ilat, ilon + 1] + lats_cor[count + 2] = lats_crnr[ilat + 1, ilon + 1] + lats_cor[count + 3] = lats_crnr[ilat + 1, ilon] + + count += 4 + + lon_bnds = lons_cor.reshape(nlat, nlon, 4) + lat_bnds = lats_cor.reshape(nlat, nlon, 4) + + # Add to the dataset + ds["lat_bnds"] = ( + (ds[lat].dims[0], ds[lat].dims[1], "vertices"), + lat_bnds, + ) + ds["lon_bnds"] = ( + (ds[lon].dims[0], ds[lon].dims[1], "vertices"), + lon_bnds, + ) + + return ds + + +def generate_bounds_rectilinear(ds, lat, lon): + """Compute bounds for rectilinear grids. + + The bounds will be attached as coords to the xarray.Dataset of the Grid object. + If no bounds can be created, a warning is issued. It is assumed but not ensured that no + duplicated cells are present in the grid. + + Parameters + ---------- + ds : xarray.Dataset + . + lat : str + Latitude variable name. + lon : str + Longitude variable name. + + Returns + ------- + xarray.Dataset + Dataset with attached bounds. 
+ """ + # Detect shape + nlat, nlon, ncells = detect_shape( + ds=ds, lat=lat, lon=lon, grid_type="regular_lat_lon" + ) + + # Assuming lat / lon values are strong monotonically decreasing/increasing + # Latitude / Longitude bounds shaped (nlat, 2) / (nlon, 2) + lat_bnds = np.zeros((ds[lat].shape[0], 2), dtype=np.float32) + lon_bnds = np.zeros((ds[lon].shape[0], 2), dtype=np.float32) + + # lat_bnds + # positive<0 for strong monotonically increasing + # positive>0 for strong monotonically decreasing + positive = ds[lat].values[0] - ds[lat].values[1] + gspacingl = abs(positive) + gspacingu = abs(ds[lat].values[-1] - ds[lat].values[-2]) + if positive < 0: + lat_bnds[1:, 0] = (ds[lat].values[:-1] + ds[lat].values[1:]) / 2.0 + lat_bnds[:-1, 1] = lat_bnds[1:, 0] + lat_bnds[0, 0] = ds[lat].values[0] - gspacingl / 2.0 + lat_bnds[-1, 1] = ds[lat].values[-1] + gspacingu / 2.0 + elif positive > 0: + lat_bnds[1:, 1] = (ds[lat].values[:-1] + ds[lat].values[1:]) / 2.0 + lat_bnds[:-1, 0] = lat_bnds[1:, 1] + lat_bnds[0, 1] = ds[lat].values[0] + gspacingl / 2.0 + lat_bnds[-1, 0] = ds[lat].values[-1] - gspacingu / 2.0 + else: + warnings.warn( + "The bounds could not be calculated since the latitude and/or longitude " + "values are not strong monotonically decreasing/increasing." + ) + return ds + + lat_bnds = np.where(lat_bnds < -90.0, -90.0, lat_bnds) + lat_bnds = np.where(lat_bnds > 90.0, 90.0, lat_bnds) + + # lon_bnds + positive = ds[lon].values[0] - ds[lon].values[1] + gspacingl = abs(positive) + gspacingu = abs(ds[lon].values[-1] - ds[lon].values[-2]) + if positive < 0: + lon_bnds[1:, 0] = (ds[lon].values[:-1] + ds[lon].values[1:]) / 2.0 + lon_bnds[:-1, 1] = lon_bnds[1:, 0] + lon_bnds[0, 0] = ds[lon].values[0] - gspacingl / 2.0 + lon_bnds[-1, 1] = ds[lon].values[-1] + gspacingu / 2.0 + elif positive > 0: + lon_bnds[1:, 1] = (ds[lon].values[:-1] + ds[lon].values[1:]) / 2.0 + lon_bnds[:-1, 0] = lon_bnds[1:, 1] + lon_bnds[0, 1] = ds[lon].values[0] + gspacingl / 2.0 + lon_bnds[-1, 0] = ds[lon].values[-1] - gspacingu / 2.0 + else: + warnings.warn( + "The bounds could not be calculated since the latitude and/or longitude " + "values are not strong monotonically decreasing/increasing." + ) + return ds + + # Add to the dataset + ds["lat_bnds"] = ((ds[lat].dims[0], "bnds"), lat_bnds) + ds["lon_bnds"] = ((ds[lon].dims[0], "bnds"), lon_bnds) + + return ds + + def detect_coordinate(ds, coord_type): """Use cf_xarray to obtain the variable name of the requested coordinate. diff --git a/clisops/utils/output_utils.py b/clisops/utils/output_utils.py index 0471be6c..2bb3a5ed 100644 --- a/clisops/utils/output_utils.py +++ b/clisops/utils/output_utils.py @@ -4,6 +4,7 @@ import tempfile import time from datetime import datetime as dt +from datetime import timedelta as td from pathlib import Path from typing import List, Tuple, Union @@ -247,3 +248,72 @@ def get_output(ds, output_type, output_dir, namer): logger.info(f"Wrote output file: {output_path}") return output_path + + +class FileLock: + """Create and release a lockfile. 
+ + + Adapted from https://github.com/cedadev/cmip6-object-store/cmip6_zarr/file_lock.py + """ + + def __init__(self, fpath): + """Initialize Lock for 'fpath'.""" + self._fpath = fpath + dr = os.path.dirname(fpath) + if dr and not os.path.isdir(dr): + os.makedirs(dr) + + self.state = "UNLOCKED" + + def acquire(self, timeout=10): + """Create the actual lockfile; raise an error if it still exists after 'timeout'.""" + start = dt.now() + deadline = start + td(seconds=timeout) + + while dt.now() < deadline: + if not os.path.isfile(self._fpath): + Path(self._fpath).touch() + break + + time.sleep(3) + else: + raise Exception(f"Could not obtain file lock on {self._fpath}") + + self.state = "LOCKED" + + def release(self): + """Release lock, i.e. delete lockfile.""" + if os.path.isfile(self._fpath): + try: + os.remove(self._fpath) + except FileNotFoundError: + pass + + self.state = "UNLOCKED" + + +def create_lock(fname: Union[str, Path]): + """Check whether a lockfile already exists and, if not, create it. + + Parameters + ---------- + fname : str or Path + Path of the lockfile to be created. + + Returns + ------- + FileLock object or None. + """ + lock_obj = FileLock(fname) + try: + lock_obj.acquire(timeout=10) + locked = False + except Exception as exc: + if str(exc) == f"Could not obtain file lock on {fname}": + locked = True + else: + raise Exception(exc) + if locked: + return None + else: + return lock_obj diff --git a/docs/api.rst b/docs/api.rst index 435cc11b..d0042d0e 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -18,6 +18,14 @@ Core average functionality :undoc-members: :show-inheritance: +Core regrid functionality +========================= + +.. automodule:: clisops.core.regrid + :members: + :undoc-members: + :show-inheritance: + Subset operation ================ @@ -38,6 +46,15 @@ Average operations :undoc-members: :show-inheritance: +Regrid operation +================= + +.. 
automodule:: clisops.ops.regrid + :noindex: + :members: + :undoc-members: + :show-inheritance: + Common functions ================ diff --git a/docs/environment.yml b/docs/environment.yml index fc538064..348b294c 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -1,26 +1,45 @@ # conda env create -f environment.yml name: clisops-docs channels: -- conda-forge -- defaults + - conda-forge dependencies: -- sphinx -- nbsphinx -- ipython -- pandoc -- pip -- numpy>=1.16 -- xarray>=0.15 -- pandas>=1.0.3 -- cftime>=1.4.1 -- netCDF4>=1.4 -- xesmf>=0.6.2 -- poppler>=0.67 -- shapely>=1.6 -- geopandas>=0.7 -- dask>=2.6.0 -- gdal<3.5 -- bottleneck>=1.3.1,<1.4 -- pyproj>=2.5 -- roocs-utils>=0.2.1 -- cf_xarray>=0.3.1 + - python >=3.8,<3.12 + - flit + - bottleneck >=1.3.1 + - cf_xarray >=0.8.5 + - cftime >=1.4.1 + - dask >=2.6.0 + - gdal >=3.0 + - geopandas >=0.11 + - loguru >=0.5.3 + - netCDF4 >=1.4 + - numpy >=1.16 + - packaging + - pandas >=1.0.3 + - pooch + - poppler >=0.67 + - pyproj >=3.3.0 + - requests >=2.0 + - roocs-grids >=0.1.2 + - roocs-utils >=0.6.4,<0.7 + - shapely >=1.9 + - xarray >=0.21,<2023.3.0 # https://github.com/pydata/xarray/issues/7794 + - xesmf >=0.8.2 + # Documentation + - cartopy >=0.20.2 + - gitpython + - ipykernel + - ipython + - ipython_genutils + - jupyter_client + - matplotlib >=3.5.2 + - nbsphinx + - pandoc + - sphinx + - sphinx-rtd-theme >=1.0 + # Upstream + - pip + - pip: + - psy-maps +# - cf-xarray @ git+https://github.com/xarray-contrib/cf-xarray/@main#egg=cf-xarray +# - roocs-utils @ git+https://github.com/roocs/roocs-utils.git@master#egg=roocs-utils diff --git a/docs/notebooks/index.rst b/docs/notebooks/index.rst index 893b89c3..6f1ec908 100644 --- a/docs/notebooks/index.rst +++ b/docs/notebooks/index.rst @@ -8,3 +8,4 @@ Examples subset core_subset average_over_dims + regrid diff --git a/docs/notebooks/regrid.ipynb b/docs/notebooks/regrid.ipynb new file mode 100644 index 00000000..9c22807a --- /dev/null +++ b/docs/notebooks/regrid.ipynb @@ -0,0 +1,1033 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "eb091c4d", + "metadata": {}, + "source": [ + "# `clisops` regridding functionalities - powered by `xesmf`\n", + "\n", + "The regridding functionalities of clisops consist of the regridding operator/function `regrid` in `clisops.ops`, allowing one-line remapping of `xarray.Datasets` or `xarray.DataArrays`, while orchestrating the use of classes and functions in `clisops.core`:\n", + "- the `Grid` and `Weights` classes, to check and pre-process input as well as output grids and to generate the remapping weights\n", + "- a `regrid` function, performing the remapping by applying the generated weights on the input data\n", + "\n", + "For the weight generation and the regridding, the [xESMF](https://github.com/pangeo-data/xESMF) `Regridder` class is used, which itself allows an easy application of many of the remapping functionalities of [ESMF](https://earthsystemmodeling.org/)/[ESMPy](https://github.com/esmf-org/esmf/blob/develop/src/addon/ESMPy/README.md)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f2f10f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import cartopy.crs as ccrs\n", + "import psyplot.project as psy\n", + "import numpy as np\n", + "import xarray as xr\n", + "import cf_xarray as cfxr\n", + "\n", + "from pathlib import Path\n", + "from git import Repo\n", + "# Set required environment variable for ESMPy\n", + "import os \n", + "os.environ['ESMFMKFILE'] = str(Path(os.__file__).parent.parent / 'esmf.mk')\n", + "import xesmf as xe\n", + "\n", + "import clisops as cl # atm. the regrid-main-martin branch of clisops\n", + "import clisops.ops as clops\n", + "import clisops.core as clore\n", + "from clisops.utils import dataset_utils\n", + "from roocs_grids import get_grid_file, grid_dict, grid_annotations\n", + "\n", + "print(f\"Using xarray in version {xr.__version__}\")\n", + "print(f\"Using cf_xarray in version {cfxr.__version__}\")\n", + "print(f\"Using xESMF in version {xe.__version__}\")\n", + "print(f\"Using clisops in version {cl.__version__}\")\n", + "\n", + "xr.set_options(display_style='html')\n", + "\n", + "## Turn off warnings?\n", + "import warnings\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fcad6b4", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize test data\n", + "\n", + "# Initialize mini-esgf-data\n", + "MINIESGF_URL=\"https://github.com/roocs/mini-esgf-data\"\n", + "branch = \"master\"\n", + "MINIESGF = Path(Path.home(),\".mini-esgf-data\", branch)\n", + "\n", + "# Retrieve mini-esgf test data\n", + "if not os.path.isdir(MINIESGF):\n", + " repo = Repo.clone_from(MINIESGF_URL, MINIESGF)\n", + " repo.git.checkout(branch)\n", + "else:\n", + " repo = Repo(MINIESGF)\n", + " repo.git.checkout(branch)\n", + " repo.remotes[0].pull()\n", + " \n", + "MINIESGF=Path(MINIESGF,\"test_data\")" + ] + }, + { + "cell_type": "markdown", + "id": "edb25e1f", + "metadata": {}, + "source": [ + "## `clisops.ops.regrid`\n", + "\n", + "One-line remapping with `clisops.ops.regrid`\n", + "```python\n", + "def regrid(\n", + " ds: Union[xr.Dataset, str, Path],\n", + " *,\n", + " method: Optional[str] = \"nearest_s2d\",\n", + " adaptive_masking_threshold: Optional[Union[int, float]] = 0.5,\n", + " grid: Optional[\n", + " Union[xr.Dataset, xr.DataArray, int, float, tuple, str]\n", + " ] = \"adaptive\",\n", + " output_dir: Optional[Union[str, Path]] = None,\n", + " output_type: Optional[str] = \"netcdf\",\n", + " split_method: Optional[str] = \"time:auto\",\n", + " file_namer: Optional[str] = \"standard\",\n", + " keep_attrs: Optional[Union[bool, str]] = True,\n", + ") -> List[Union[xr.Dataset, str]] \n", + "```\n", + "The different options for the `method`, `grid` and `adaptive_masking_threshold` parameters are described in below sections:\n", + "\n", + "* [clisops.core.Grid](#clisops.core.Grid)\n", + "* [clisops.core.Weights](#clisops.core.Weights)\n", + "* [clisops.core.regrid](#clisops.core.regrid)\n" + ] + }, + { + "cell_type": "markdown", + "id": "0aa5d035", + "metadata": {}, + "source": [ + "### Remap a global `xarray.Dataset` to a global 2.5 degree grid using the bilinear method\n", + "\n", + "#### Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09bea3a3", + "metadata": {}, + "outputs": [], + "source": [ + "ds_vert_path = Path(MINIESGF, 
\"badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/AERmon/\"\n", + " \"o3/gn/v20190710/o3_AERmon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc\")\n", + "ds_vert = xr.open_dataset(ds_vert_path)\n", + "ds_vert" + ] + }, + { + "cell_type": "markdown", + "id": "3e8e01f4", + "metadata": {}, + "source": [ + "#### Take a look at the grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ca60d65", + "metadata": {}, + "outputs": [], + "source": [ + "# Create 2D coordinate variables\n", + "lon,lat = np.meshgrid(ds_vert[\"lon\"].data, ds_vert[\"lat\"].data)\n", + "\n", + "# Plot\n", + "plt.figure(figsize=(8,5))\n", + "plt.scatter(lon[::3, ::3], lat[::3, ::3], s=0.5) \n", + "plt.xlabel('lon')\n", + "plt.ylabel('lat')" + ] + }, + { + "cell_type": "markdown", + "id": "4a085f74", + "metadata": {}, + "source": [ + "#### Remap to global 2.5 degree grid with the bilinear method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f38b21f7", + "metadata": {}, + "outputs": [], + "source": [ + "ds_remap = clops.regrid(ds_vert, method=\"bilinear\", grid=\"2pt5deg\", output_type=\"xarray\")[0]\n", + "ds_remap" + ] + }, + { + "cell_type": "markdown", + "id": "acbdf431", + "metadata": {}, + "source": [ + "#### Plot the remapped data next to the source data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a18af12", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(ncols=2, figsize=(18,4), subplot_kw={'projection': ccrs.PlateCarree()})\n", + "for ax in axes:\n", + " ax.coastlines()\n", + "\n", + "# Source data\n", + "ds_vert.o3.isel(time=0, lev=0).plot.pcolormesh(ax=axes[0], x=\"lon\", y=\"lat\", shading=\"auto\")\n", + "axes[0].title.set_text(\"Source - MPI-ESM1-2-LR ECHAM6 (T63L47, ~1.9° resolution)\")\n", + "# Remapped data\n", + "ds_remap.o3.isel(time=0, lev=0).plot.pcolormesh(ax=axes[1], x=\"lon\", y=\"lat\", shading=\"auto\")\n", + "axes[1].title.set_text(\"Target - regular lat-lon (2.5° resolution)\")" + ] + }, + { + "cell_type": "markdown", + "id": "6e63e114", + "metadata": {}, + "source": [ + "### Remap regional `xarray.Dataset` to a regional grid of adaptive resolution using the bilinear method\n", + "Adaptive resolution means, that the regular lat-lon target grid will have approximately the same resolution as the source grid." 
+ ] + }, + { + "cell_type": "markdown", + "id": "6434bcca", + "metadata": {}, + "source": [ + "#### Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff637d57", + "metadata": {}, + "outputs": [], + "source": [ + "ds_cordex_path = Path(MINIESGF, \"pool/data/CORDEX/data/cordex/output/EUR-22/GERICS/MPI-M-MPI-ESM-LR/\"\n", + " \"rcp85/r1i1p1/GERICS-REMO2015/v1/mon/tas/v20191029/\"\n", + " \"tas_EUR-22_MPI-M-MPI-ESM-LR_rcp85_r1i1p1_GERICS-REMO2015_v1_mon_202101.nc\")\n", + "ds_cordex = xr.open_dataset(ds_cordex_path)\n", + "ds_cordex" + ] + }, + { + "cell_type": "markdown", + "id": "fe1f9feb", + "metadata": {}, + "source": [ + "#### Take a look at the grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca6661be", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(8,5))\n", + "plt.scatter(ds_cordex['lon'][::4, ::4], ds_cordex['lat'][::4, ::4], s=0.1) \n", + "plt.xlabel('lon')\n", + "plt.ylabel('lat')" + ] + }, + { + "cell_type": "markdown", + "id": "280c1bb0", + "metadata": {}, + "source": [ + "#### Remap to regional regular lat-lon grid of adaptive resolution with the bilinear method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47cdfe78", + "metadata": {}, + "outputs": [], + "source": [ + "ds_remap = clops.regrid(ds_cordex, method=\"bilinear\", grid=\"adaptive\", output_type=\"xarray\")[0]\n", + "ds_remap" + ] + }, + { + "cell_type": "markdown", + "id": "6a4d0b86", + "metadata": {}, + "source": [ + "#### Plot the remapped data next to the source data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd8f65b6", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(ncols=2, figsize=(18,4), subplot_kw={'projection': ccrs.PlateCarree()})\n", + "for ax in axes: ax.coastlines()\n", + "\n", + "# Source data\n", + "ds_cordex.tas.isel(time=0).plot.pcolormesh(ax=axes[0], x=\"lon\", y=\"lat\", shading=\"auto\", cmap=\"RdBu_r\")\n", + "axes[0].title.set_text(\"Source - GERICS-REMO2015 (EUR22, ~0.22° resolution)\")\n", + "# Remapped data\n", + "ds_remap.tas.isel(time=0).plot.pcolormesh(ax=axes[1], x=\"lon\", y=\"lat\", shading=\"auto\", cmap=\"RdBu_r\")\n", + "axes[1].title.set_text(\"Target - regional regular lat-lon (adaptive resolution)\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0c2a4b6", + "metadata": {}, + "source": [ + "### Remap unstructured `xarray.Dataset` to a global grid of adaptive resolution using the nearest neighbour method\n", + "\n", + "For unstructured grids, at least for the moment, only the nearest neighbour remapping method is supported.\n", + "\n", + "#### Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0b8dc76", + "metadata": {}, + "outputs": [], + "source": [ + "ds_icono_path = Path(MINIESGF, \"badc/cmip6/data/CMIP6/CMIP/MPI-M/ICON-ESM-LR/historical/\"\n", + " \"r1i1p1f1/Omon/thetao/gn/v20210215/\"\n", + " \"thetao_Omon_ICON-ESM-LR_historical_r1i1p1f1_gn_185001.nc\")\n", + "ds_icono = xr.open_dataset(ds_icono_path)\n", + "ds_icono" + ] + }, + { + "cell_type": "markdown", + "id": "b6dfc32e", + "metadata": {}, + "source": [ + "#### Take a look at the grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "998dcd7a", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(16,9))\n", + "plt.scatter(ds_icono['longitude'][::2], ds_icono['latitude'][::2], s=0.05) \n", + "plt.xlabel('lon')\n", + "plt.ylabel('lat')" + ] + }, + { + "cell_type": 
"markdown", + "id": "56f0f48b", + "metadata": {}, + "source": [ + "#### Remap to global grid of adaptive resolution with the nearest neighbour method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f68a902", + "metadata": {}, + "outputs": [], + "source": [ + "ds_remap = clops.regrid(ds_icono, method=\"nearest_s2d\", grid=\"adaptive\", output_type=\"xarray\")[0]\n", + "ds_remap" + ] + }, + { + "cell_type": "markdown", + "id": "9d0459c6", + "metadata": {}, + "source": [ + "#### Plot source data and remapped data\n", + "\n", + "(Using [psyplot](https://psyplot.github.io/) to plot the unstructured data since xarray does not (yet?) support it.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31d53774", + "metadata": {}, + "outputs": [], + "source": [ + "# Source data\n", + "maps=psy.plot.mapplot(ds_icono_path, cmap=\"RdBu_r\", title=\"Source - ICON-ESM-LR ICON-O (Ruby-0, 40km resolution)\", \n", + " time=[0], lev=[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5fc5d12", + "metadata": {}, + "outputs": [], + "source": [ + "# Remapped data\n", + "plt.figure(figsize=(9,4));\n", + "ax = plt.axes(projection=ccrs.PlateCarree());\n", + "ds_remap.thetao.isel(time=0, lev=0).plot.pcolormesh(ax=ax, x=\"lon\", y=\"lat\", shading=\"auto\",\n", + " cmap=\"RdBu_r\", vmin = -1, vmax=40)\n", + "ax.title.set_text(\"Target - regular lat-lon (adaptive resolution)\")\n", + "ax.coastlines()" + ] + }, + { + "cell_type": "markdown", + "id": "31ca8f33", + "metadata": {}, + "source": [ + "\n", + "\n", + "## `clisops.core.Grid`" + ] + }, + { + "cell_type": "markdown", + "id": "418e5f4d", + "metadata": {}, + "source": [ + "### Create a grid object from an `xarray.Dataset`\n", + "\n", + "#### Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8db00e2c", + "metadata": {}, + "outputs": [], + "source": [ + "dso_path = Path(MINIESGF, \"badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-HR/historical/r1i1p1f1/Omon/tos/gn/\"\n", + " \"v20190710/tos_Omon_MPI-ESM1-2-HR_historical_r1i1p1f1_gn_185001.nc\")\n", + "dso = xr.open_dataset(dso_path)\n", + "dso" + ] + }, + { + "cell_type": "markdown", + "id": "63ca9978", + "metadata": {}, + "source": [ + "#### Create the Grid object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "462bba5d", + "metadata": {}, + "outputs": [], + "source": [ + "grido = clore.Grid(ds=dso)\n", + "grido" + ] + }, + { + "cell_type": "markdown", + "id": "4daaa5d9", + "metadata": {}, + "source": [ + "The `xarray.Dataset` is attached to the `clisops.core.Grid` object. Auxiliary coordinates and data variables have been (re)set appropriately." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bb185c4", + "metadata": {}, + "outputs": [], + "source": [ + "grido.ds" + ] + }, + { + "cell_type": "markdown", + "id": "3a0a8abc", + "metadata": {}, + "source": [ + "#### Plot the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cee07078", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(9,4))\n", + "ax = plt.axes(projection=ccrs.PlateCarree())\n", + "grido.ds.tos.isel(time=0).plot.pcolormesh(ax=ax, x=grido.lon, y=grido.lat, shading=\"auto\",\n", + " cmap=\"RdBu_r\", vmin = -1, vmax=40)\n", + "ax.coastlines()" + ] + }, + { + "cell_type": "markdown", + "id": "19d1681f", + "metadata": {}, + "source": [ + "### Create a grid object from an `xarray.DataArray`\n", + "\n", + "Note that `xarray.DataArray` objects do not support the definition of bounds for coordinate variables.\n", + "\n", + "#### Extract tos `DataArray`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "809a9293", + "metadata": {}, + "outputs": [], + "source": [ + "dao = dso.tos\n", + "dao" + ] + }, + { + "cell_type": "markdown", + "id": "e7a5dde9", + "metadata": {}, + "source": [ + "#### Create a Grid object for the MPIOM tos `DataArray`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e98d6cf", + "metadata": {}, + "outputs": [], + "source": [ + "grido_tos = clore.Grid(ds=dao)\n", + "grido_tos" + ] + }, + { + "cell_type": "markdown", + "id": "3af96e5f", + "metadata": {}, + "source": [ + "### Create a grid object using a `grid_instructor`\n", + "\n", + "* global grid: `grid_instructor = (lon_step, lat_step)` or `grid_instructor = step`\n", + "* regional grid: `grid_instructor = (lon_start, lon_end, lon_step, lat_start, lat_end, lat_step)` or `grid_instructor = (start, end, step)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac51c540", + "metadata": {}, + "outputs": [], + "source": [ + "grid_1deg = clore.Grid(grid_instructor=1)\n", + "grid_1deg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a03f9f6e", + "metadata": {}, + "outputs": [], + "source": [ + "grid_1degx2deg_regional = clore.Grid(grid_instructor=(0., 90., 1., 35., 50., 2.))\n", + "grid_1degx2deg_regional" + ] + }, + { + "cell_type": "markdown", + "id": "b2473fa8", + "metadata": {}, + "source": [ + "### Create a grid object using a `grid_id`\n", + "\n", + "This makes use of the predefined grids of `roocs_grids`, a collection of grids used, for example, for the [IPCC Atlas](https://github.com/IPCC-WG1/Atlas/tree/main/reference-grids) and for [CMIP6 Regridding Weights generation](https://docs.google.com/document/d/1BfVVsKAk9MAsOYstwFSWI2ZBt5mrO_Nmcu7rLGDuL08/edit)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "744ba37b", + "metadata": {}, + "outputs": [], + "source": [ + "for key, gridinfo in grid_annotations.items(): print(f\"- {key:20} {gridinfo}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "243cf573", + "metadata": {}, + "outputs": [], + "source": [ + "grid_era5 = clore.Grid(grid_id = \"0pt25deg_era5\")\n", + "grid_era5" + ] + }, + { + "cell_type": "markdown", + "id": "d436424c", + "metadata": {}, + "source": [ + "### `clisops.core.Grid` objects can be compared to one another\n", + "\n", + "Optional verbose output gives information on where the grids differ: in lat, lon, lat_bnds, lon_bnds and/or the mask.\n", + "\n", + "#### Compare the tos dataset to the tos dataarray" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45888a28", + "metadata": {}, + "outputs": [], + "source": [ + "comp = grido.compare_grid(grido_tos, verbose = True)\n", + "print(\"Grids are equal?\", comp)" + ] + }, + { + "cell_type": "markdown", + "id": "41bca466", + "metadata": {}, + "source": [ + "#### Compare the two 0.25° ERA5 grids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9468e89", + "metadata": {}, + "outputs": [], + "source": [ + "# Create the Grid object\n", + "grid_era5_lsm = clore.Grid(grid_id = \"0pt25deg_era5_lsm\", compute_bounds=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07338d79", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare\n", + "comp = grid_era5.compare_grid(grid_era5_lsm, verbose=True)\n", + "print(\"Grids are equal?\", comp)" + ] + }, + { + "cell_type": "markdown", + "id": "5f358d9f", + "metadata": {}, + "source": [ + "### Strip `clisops.core.Grid` objects of all `data_vars` and `coords` unrelated to the horizontal grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b55721a2", + "metadata": {}, + "outputs": [], + "source": [ + "grid_era5_lsm.ds" + ] + }, + { + "cell_type": "markdown", + "id": "84b24eb6", + "metadata": {}, + "source": [ + "The parameter `keep_attrs` can be set; the default is `False`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3592a2ab", + "metadata": {}, + "outputs": [], + "source": [ + "grid_era5_lsm._drop_vars(keep_attrs=False)\n", + "grid_era5_lsm.ds" + ] + }, + { + "cell_type": "markdown", + "id": "1f198140", + "metadata": {}, + "source": [ + "### Transfer coordinate variables that are unrelated to the horizontal grid between `clisops.core.Grid` objects\n", + "\n", + "The parameter `keep_attrs` can be set; the default is `True`. 
All settings for `keep_attrs` are described later in section [clisops.core.regrid](#clisops.core.regrid).\n", + "\n", + "#### Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c138fdac", + "metadata": {}, + "outputs": [], + "source": [ + "ds_vert_path = Path(MINIESGF, \"badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/\"\n", + " \"AERmon/o3/gn/v20190710/o3_AERmon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc\")\n", + "ds_vert = xr.open_dataset(ds_vert_path)\n", + "ds_vert" + ] + }, + { + "cell_type": "markdown", + "id": "b7a6c965", + "metadata": {}, + "source": [ + "#### Create the grid object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7aae205d", + "metadata": {}, + "outputs": [], + "source": [ + "grid_vert = clore.Grid(ds_vert)\n", + "grid_vert" + ] + }, + { + "cell_type": "markdown", + "id": "331f439e", + "metadata": {}, + "source": [ + "#### Transfer the coordinates to the ERA5 grid object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dab5a30d", + "metadata": {}, + "outputs": [], + "source": [ + "grid_era5_lsm._transfer_coords(grid_vert, keep_attrs=True)\n", + "grid_era5_lsm.ds" + ] + }, + { + "cell_type": "markdown", + "id": "e0250662", + "metadata": {}, + "source": [ + "\n", + "\n", + "## `clisops.core.Weights`\n", + "\n", + "Create regridding weights to regrid between two grids. The following [xESMF remapping methods](https://pangeo-xesmf.readthedocs.io/en/latest/notebooks/Compare_algorithms.html) are supported:\n", + "* `nearest_s2d`\n", + "* `bilinear`\n", + "* `conservative`\n", + "* `patch`" + ] + }, + { + "cell_type": "markdown", + "id": "37266736", + "metadata": {}, + "source": [ + "### Create 2-degree target grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c11b864", + "metadata": {}, + "outputs": [], + "source": [ + "grid_2deg = clore.Grid(grid_id=\"2deg_lsm\", compute_bounds=True)\n", + "grid_2deg" + ] + }, + { + "cell_type": "markdown", + "id": "3e912197", + "metadata": {}, + "source": [ + "### Create conservative remapping weights using the `clisops.core.Weights` class\n`grid_in` and `grid_out` are `Grid` objects." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67480170", + "metadata": {}, + "outputs": [], + "source": [ + "%time weights = clore.Weights(grid_in = grido, grid_out = grid_2deg, method=\"conservative\")" + ] + }, + { + "cell_type": "markdown", + "id": "db9fdde3", + "metadata": {}, + "source": [ + "### Local weights cache\n", + "\n", + "Weights are cached on disk and do not have to be created more than once. 
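The weights cache can also be inspected and re-pointed programmatically. A minimal sketch, assuming the `CONFIG` mapping and the `clisops:grid_weights` / `local_weights_dir` entries asserted in `tests/test_config.py` further below; the target directory is a placeholder:

```python
# Minimal sketch: inspect and re-point the local weights cache.
# Section and key names follow the CONFIG assertions in tests/test_config.py;
# the target directory is a placeholder.
from clisops import CONFIG
from clisops import core as clore

print(CONFIG["clisops:grid_weights"]["local_weights_dir"])  # current cache directory
clore.weights_cache_init("/dir/for/weights/cache")          # re-initialize under a new path
```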
The default cache directory is `/tmp/clisops_grid_weights`; it can be adjusted either in the `roocs.ini` configuration file found in the clisops installation directory or via:\n", + "```python\n", + "from clisops import core as clore\n", + "clore.weights_cache_init(\"/dir/for/weights/cache\")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b21f10e", + "metadata": {}, + "outputs": [], + "source": [ + "!ls -sh /tmp/clisops_grid_weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b762418f", + "metadata": {}, + "outputs": [], + "source": [ + "!cat /tmp/clisops_grid_weights/weights_*_conservative.json" + ] + }, + { + "cell_type": "markdown", + "id": "46ed2631", + "metadata": {}, + "source": [ + "Now the weights will be read directly from the cache:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "734b0888", + "metadata": {}, + "outputs": [], + "source": [ + "%time weights = clore.Weights(grid_in = grido, grid_out = grid_2deg, method=\"conservative\")" + ] + }, + { + "cell_type": "markdown", + "id": "83e7a856", + "metadata": {}, + "source": [ + "The weights cache can be flushed, which removes all weight and grid files as well as the JSON files holding the metadata. To see what would be removed, one can use the `dryrun=True` parameter. To re-initialize the weights cache in a different directory, one can use the `weights_dir_init=\"/new/dir/for/weights/cache\"` parameter. Even when re-initializing the weights cache under a new path, `clore.weights_cache_flush` does not remove any directory, only the files listed above. When `dryrun` is not set, the files that are being deleted can be displayed with `verbose=True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c525707", + "metadata": {}, + "outputs": [], + "source": [ + "clore.weights_cache_flush(dryrun=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fde9539", + "metadata": {}, + "outputs": [], + "source": [ + "clore.weights_cache_flush(verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "e66a563f", + "metadata": {}, + "source": [ + "\n", + "\n", + "## `clisops.core.regrid`\n", + "\n", + "This function performs the actual regridding and returns the resulting `xarray.Dataset`:\n", + "\n", + "```python\n", + "def regrid(\n", + " grid_in: Grid,\n", + " grid_out: Grid,\n", + " weights: Weights,\n", + " adaptive_masking_threshold: Optional[float] = 0.5,\n", + " keep_attrs: Optional[bool] = True,\n", + "):\n", + "```\n", + "\n", + "* `grid_in` and `grid_out` are `Grid` objects, `weights` is a `Weights` object.\n", + "* `adaptive_masking_threshold` (AMT): a value within the [0., 1.] interval that defines the maximum `RATIO` of missing values amongst the total number of data values contributing to the calculation of the target grid cell value. For a missing fraction within [0., AMT[ of the contributing source data, the target grid cell value will be re-normalized by the factor `1./(1.-RATIO)`; else, it will be set to missing_value. Thus, if AMT is set to 1, all source grid cells that contribute to a target grid cell must be missing in order for the target grid cell to be defined as missing itself. Values greater than 1 or less than 0 will cause adaptive masking to be turned off. This adaptive masking technique allows reusing generated weights for differently masked data (e.g.
land-sea masks or orographic masks that vary with depth / height).\n", + "* `keep_attrs` can have the following settings:\n", + " * `True` : The resulting `xarray.Dataset` will have all attributes of `grid_in.ds.attrs`, except for attributes that have to be added or altered due to the new grid.\n", + " * `False` : The resulting `xarray.Dataset` will have no attributes except for attributes generated by the regridding process.\n", + " * `\"target\"` : The resulting `xarray.Dataset` will have all attributes of `grid_out.ds.attrs`, except for attributes generated by the regridding process. Not recommended.\n", + " \n", + " \n", + "### The following example shows the application of the function and the effect of adaptive masking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1f50b0f", + "metadata": {}, + "outputs": [], + "source": [ + "ds_out_amt0 = clore.regrid(grido, grid_2deg, weights, adaptive_masking_threshold=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb325c4c", + "metadata": {}, + "outputs": [], + "source": [ + "ds_out_amt1 = clore.regrid(grido, grid_2deg, weights, adaptive_masking_threshold=0.5)" + ] + }, + { + "cell_type": "markdown", + "id": "5c8ba0ae", + "metadata": {}, + "source": [ + "#### Plot the resulting data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a220b89", + "metadata": {}, + "outputs": [], + "source": [ + "# Create panel plot of regridded data (global)\n", + "fig, axes = plt.subplots(ncols=2, nrows=1, \n", + " figsize=(18, 5), # global\n", + " subplot_kw={'projection': ccrs.PlateCarree()})\n", + "\n", + "ds_out_amt0[\"tos\"].isel(time=0).plot.pcolormesh(ax=axes[0], vmin=0, vmax=30, cmap=\"plasma\")\n", + "axes[0].title.set_text(\"Target (2° regular lat-lon) - No adaptive masking\")\n", + "\n", + "ds_out_amt1[\"tos\"].isel(time=0).plot.pcolormesh(ax=axes[1], vmin=0, vmax=30, cmap=\"plasma\")\n", + "axes[1].title.set_text(\"Target (2° regular lat-lon) - Adaptive masking\")\n", + "\n", + "for axis in axes.flatten():\n", + " axis.coastlines()\n", + " axis.set_xlabel('lon')\n", + " axis.set_ylabel('lat')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ab66990", + "metadata": {}, + "outputs": [], + "source": [ + "# Create panel plot of regridded data (Japan)\n", + "fig, axes = plt.subplots(ncols=3, nrows=1, \n", + " figsize=(18, 4), # Japan\n", + " subplot_kw={'projection': ccrs.PlateCarree()})\n", + "\n", + "grido.ds.tos.isel(time=0).plot.pcolormesh(ax=axes[0], x=grido.lon, y=grido.lat, \n", + " vmin=0, vmax=30, cmap=\"plasma\", shading=\"auto\")\n", + "axes[0].title.set_text(\"Source - MPI-ESM1-2-HR MPIOM (TP04, ~0.4° resolution)\")\n", + "\n", + "ds_out_amt0[\"tos\"].isel(time=0).plot.pcolormesh(ax=axes[1], vmin=0, vmax=30, cmap=\"plasma\")\n", + "axes[1].title.set_text(\"Target - No adaptive masking\")\n", + "\n", + "ds_out_amt1[\"tos\"].isel(time=0).plot.pcolormesh(ax=axes[2], vmin=0, vmax=30, cmap=\"plasma\")\n", + "axes[2].title.set_text(\"Target - Adaptive masking\")\n", + "\n", + "for axis in axes.flatten():\n", + " axis.coastlines()\n", + " axis.set_xlabel('lon')\n", + " axis.set_ylabel('lat')\n", + " axis.set_xlim([125, 150])\n", + " axis.set_ylim([25, 50])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
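To make the re-normalization described above concrete, here is a minimal NumPy sketch of the arithmetic for a single target grid cell (an illustration of the formula only, not the clisops implementation; the values, weights and threshold are made up):

```python
import numpy as np

# Four source cells contribute to one target cell; the weights sum to 1.
# One contribution is missing (NaN), so the missing fraction is RATIO = 0.25.
contributions = np.array([10.0, 12.0, np.nan, 14.0])
weights = np.array([0.25, 0.25, 0.25, 0.25])

ratio = weights[np.isnan(contributions)].sum()  # RATIO = 0.25
amt = 0.5                                       # adaptive_masking_threshold

if ratio < amt:
    # Missing fraction within [0., AMT[: sum the valid contributions
    # and re-normalize by the factor 1./(1.-RATIO).
    target = np.nansum(weights * contributions) / (1.0 - ratio)
else:
    # Too much missing input: the target cell itself becomes missing.
    target = np.nan

print(target)  # 12.0 - the weighted mean of the valid contributions
```

With `amt = -1` (adaptive masking turned off), the NaN contribution would propagate into the target cell, which is the difference visible in the panel plots above.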
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/environment.yml b/environment.yml index 5b5fdd3c..5edd8bf4 100644 --- a/environment.yml +++ b/environment.yml @@ -3,29 +3,27 @@ channels: - conda-forge dependencies: - python >=3.8,<3.12 - - pip - flit - bottleneck >=1.3.1 - - cf_xarray >=0.7.0 + - cf_xarray >=0.8.5 - cftime >=1.4.1 - dask >=2.6.0 - gdal >=3.0 - geopandas >=0.11 - loguru >=0.5.3 - netCDF4 >=1.4 - - numba # needed for xesmf v0.6.3, see: https://github.com/conda-forge/xesmf-feedstock/pull/24 - numpy >=1.16 - packaging - pandas >=1.0.3 - pooch - poppler >=0.67 - pyproj >=3.3.0 - - requests>=2.0 - - roocs-utils>=0.6.4,<0.7 + - requests >=2.0 + - roocs-grids>=0.1.2 + - roocs-utils >=0.6.4,<0.7 - shapely >=1.9 - - sparse >=0.8.0 # needed for xesmf v0.6.3, see: https://github.com/conda-forge/xesmf-feedstock/pull/24 - - xarray >=0.21 - - xesmf >=0.6.3 + - xarray >=0.21,<2023.3.0 # https://github.com/pydata/xarray/issues/7794 + - xesmf >=0.8.2 # Documentation - ipykernel - ipython @@ -35,4 +33,9 @@ dependencies: - nbsphinx - pandoc - sphinx - - sphinx-rtd-theme + - sphinx-rtd-theme >=1.0 + # Upstream + - pip +# - pip: +# - cf-xarray @ git+https://github.com/xarray-contrib/cf-xarray/@main#egg=cf-xarray +# - roocs-utils @ git+https://github.com/roocs/roocs-utils.git@master#egg=roocs-utils diff --git a/notebooks b/notebooks index 5823f0ca..88458cdf 120000 --- a/notebooks +++ b/notebooks @@ -1 +1 @@ -/home/tjs/git/clisops/docs/notebooks \ No newline at end of file +docs/notebooks \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d71b3c78..19dcdb10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,8 @@ dynamic = ["description", "version"] dependencies = [ "bottleneck>=1.3.1", # cf-xarray is differently named on conda-forge - "cf-xarray>=0.7.0", + "cf-xarray>=0.8.6;python_version>='3.9'", + "cf-xarray>=0.7.5,<=0.8.0;python_version=='3.8'", "cftime>=1.4.1", "dask[complete]>=2.6", "geopandas>=0.11", @@ -60,9 +61,11 @@ dependencies = [ "pooch", "pyproj>=3.3.0", "requests>=2.0", + # roocs_grids is differently named on conda-forge + "roocs_grids>=0.1.2", "roocs-utils>=0.6.4,<0.7", "shapely>=1.9", - "xarray>=0.21", + "xarray>=0.21,<2023.3.0", ] [project.optional-dependencies] @@ -70,7 +73,7 @@ dev = [ # Packaging "flit", # Dev tools and testing - "black>=23.7.0", + "black>=23.11.0", "bump2version", "flake8", "gitpython>=3.1.30", @@ -94,7 +97,7 @@ docs = [ "sphinx", "sphinx-rtd-theme>=1.0", ] -extra = ["xesmf>=0.6.2"] +extra = ["xesmf>=0.8.2"] [project.urls] "Homepage" = "https://clisops.readthedocs.io/" @@ -115,6 +118,7 @@ target-version = [ [tool.coverage.run] relative_files = true +include = ["clisops/*"] omit = ["*/tests/*.py"] [tool.flit.sdist] @@ -157,6 +161,7 @@ profile = "black" py_version = 38 append_only = true known_first_party = "clisops,_common" +skip = ["clisops/core/__init__.py"] [tool.pytest.ini_options] addopts = [ diff --git a/requirements_upstream.txt b/requirements_upstream.txt index 8305dff0..213ec32a 100644 --- a/requirements_upstream.txt +++ b/requirements_upstream.txt @@ -2,5 +2,6 @@ bottleneck @ git+https://github.com/pydata/bottleneck.git@master cf-xarray @ git+https://github.com/xarray-contrib/cf-xarray.git@main cftime @ git+https://github.com/Unidata/cftime.git@master flox @ git+https://github.com/xarray-contrib/flox.git@main +roocs-grids @ 
git+https://github.com/roocs/roocs-grids.git@main roocs-utils @ git+https://github.com/roocs/roocs-utils.git@master xarray @ git+https://github.com/pydata/xarray.git@main diff --git a/setup.cfg b/setup.cfg index cc4f1d26..81e7a941 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ search = version = "{current_version}" replace = version = "{new_version}" [flake8] -exclude = +exclude = .git, docs, build, @@ -20,7 +20,7 @@ exclude = tests/mini-esgf-data max-line-length = 88 max-complexity = 12 -ignore = +ignore = C901 E203 E231 @@ -29,9 +29,9 @@ ignore = E501 F401 F403 + F405 W503 W504 - F405 [aliases] test = pytest diff --git a/tests/_common.py b/tests/_common.py index b891a0c6..11eaab7e 100644 --- a/tests/_common.py +++ b/tests/_common.py @@ -221,22 +221,50 @@ def cmip6_archive_base(): "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/Omon/tos/gn/v20190710/tos_Omon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001-186912.nc", ).as_posix() +# test datasets used for regridding tests - one time step, full lat/lon +# cmip6 atmosphere grid +# x-res = 2.0 CMIP6_TAS_ONE_TIME_STEP = Path( MINI_ESGF_CACHE_DIR, "master/test_data/badc/cmip6/data/CMIP6/CMIP/CAS/FGOALS-g3/historical/r1i1p1f1/Amon/tas/gn/v20190818/tas_Amon_FGOALS-g3_historical_r1i1p1f1_gn_185001.nc", ).as_posix() +# cmip6 ocean grids +# x-res = 0.90 CMIP6_TOS_ONE_TIME_STEP = Path( MINI_ESGF_CACHE_DIR, "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-HR/historical/r1i1p1f1/Omon/tos/gn/v20190710/tos_Omon_MPI-ESM1-2-HR_historical_r1i1p1f1_gn_185001.nc", ).as_posix() +# cmip5 regular gridded one time step +CMIP5_MRSOS_ONE_TIME_STEP = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip5/data/cmip5/output1/MOHC/HadGEM2-ES/rcp85/day/land/day/r1i1p1/latest/mrsos/mrsos_day_HadGEM2-ES_rcp85_r1i1p1_20051201.nc", +).as_posix() + # CMIP6 dataset with weird range in its longitude coordinate (-300, 60) CMIP6_GFDL_EXTENT = Path( MINI_ESGF_CACHE_DIR, "master/test_data/badc/cmip6/data/CMIP6/CMIP/NOAA-GFDL/GFDL-CM4/historical/r1i1p1f1/Omon/sos/gn/v20180701/sos_Omon_GFDL-CM4_historical_r1i1p1f1_gn_185001.nc", ).as_posix() +# CMIP6 two files with different precision in their coordinate variables +CMIP6_TAS_PRECISION_A = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/AWI/AWI-ESM-1-1-LR/1pctCO2/r1i1p1f1/Amon/tas/gn/v20200212/tas_Amon_AWI-ESM-1-1-LR_1pctCO2_r1i1p1f1_gn_185501.nc", +).as_posix() + +CMIP6_TAS_PRECISION_B = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/AWI/AWI-ESM-1-1-LR/1pctCO2/r1i1p1f1/Amon/tas/gn/v20200212/tas_Amon_AWI-ESM-1-1-LR_1pctCO2_r1i1p1f1_gn_209901.nc", +).as_posix() + +# CMIP6 dataset with vertical axis +CMIP6_ATM_VERT_ONE_TIMESTEP = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/AERmon/o3/gn/v20190710/o3_AERmon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc", +).as_posix() + # CMIP6 2nd dataset with weird range in its longitude coordinate (-280, 80) CMIP6_IITM_EXTENT = Path( MINI_ESGF_CACHE_DIR, @@ -248,6 +276,7 @@ def cmip6_archive_base(): "master/test_data/badc/cmip6/data/CMIP6/CMIP/CNRM-CERFACS/CNRM-CM6-1-HR/historical/r1i1p1f2/Omon/tos/gn/v20191021/tos_Omon_CNRM-CM6-1-HR_historical_r1i1p1f2_gn_185001.nc", ).as_posix() +# Three CMIP6 datasets with unstructured grids, one of which has a vertical dimension CMIP6_UNSTR_FESOM_LR = Path( MINI_ESGF_CACHE_DIR, 
"master/test_data/badc/cmip6/data/CMIP6/CMIP/AWI/AWI-ESM-1-1-LR/historical/r1i1p1f1/Omon/tos/gn/v20200212/tos_Omon_AWI-ESM-1-1-LR_historical_r1i1p1f1_gn_185001.nc", @@ -258,7 +287,36 @@ def cmip6_archive_base(): "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/ICON-ESM-LR/historical/r1i1p1f1/Amon/tas/gn/v20210215/tas_Amon_ICON-ESM-LR_historical_r1i1p1f1_gn_185001.nc", ).as_posix() +CMIP6_UNSTR_VERT_ICON_O = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/ICON-ESM-LR/historical/r1i1p1f1/Omon/thetao/gn/v20210215/thetao_Omon_ICON-ESM-LR_historical_r1i1p1f1_gn_185001.nc", +).as_posix() + +# CMIP6 dataset with missing values in the auxiliary coordinate variables, but no corresponding _FillValue or missing_value attributes set +CMIP6_UNTAGGED_MISSVALS = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/NCAR/CESM2-FV2/historical/r1i1p1f1/Omon/tos/gn/v20191120/tos_Omon_CESM2-FV2_historical_r1i1p1f1_gn_200001.nc", +).as_posix() + +# CMIP6 datasets defined on a staggered grid (u and v component) +CMIP6_STAGGERED_UCOMP = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/Omon/tauuo/gn/v20200909/tauuo_Omon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc", +).as_posix() + +CMIP6_STAGGERED_VCOMP = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/Omon/tauvo/gn/v20190710/tauvo_Omon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc ", +).as_posix() + +# CORDEX dataset on regional curvilinear grid CORDEX_TAS_ONE_TIMESTEP = Path( MINI_ESGF_CACHE_DIR, "master/test_data/pool/data/CORDEX/data/cordex/output/EUR-22/GERICS/MPI-M-MPI-ESM-LR/rcp85/r1i1p1/GERICS-REMO2015/v1/mon/tas/v20191029/tas_EUR-22_MPI-M-MPI-ESM-LR_rcp85_r1i1p1_GERICS-REMO2015_v1_mon_202101.nc", ).as_posix() + +# CORDEX dataset on regional curvilinear grid without defined bounds +CORDEX_TAS_NO_BOUNDS = Path( + MINI_ESGF_CACHE_DIR, + "master/test_data/pool/data/CORDEX/data/cordex/output/EUR-11/KNMI/MPI-M-MPI-ESM-LR/rcp85/r1i1p1/KNMI-RACMO22E/v1/mon/tas/v20190625/tas_EUR-11_MPI-M-MPI-ESM-LR_rcp85_r1i1p1_KNMI-RACMO22E_v1_mon_209101.nc", +).as_posix() diff --git a/tests/test_config.py b/tests/test_config.py index a95ca3a4..33470a39 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,6 +6,10 @@ def test_local_config_loads(): assert "clisops:read" in CONFIG assert "file_size_limit" in CONFIG["clisops:write"] + assert "clisops:grid_weights" in CONFIG + assert "local_weights_dir" in CONFIG["clisops:grid_weights"] + assert "clisops:coordinate_precision" in CONFIG + assert "hor_coord_decimals" in CONFIG["clisops:coordinate_precision"] def test_dask_env_variables(): diff --git a/tests/test_core_average.py b/tests/test_core_average.py index 1331d0a8..a5ad0318 100644 --- a/tests/test_core_average.py +++ b/tests/test_core_average.py @@ -4,7 +4,7 @@ import numpy as np import pytest import xarray as xr -from pkg_resources import parse_version +from packaging.version import Version from roocs_utils.exceptions import InvalidParameterValue from roocs_utils.xarray_utils import xarray_utils as xu @@ -15,7 +15,7 @@ try: import xesmf - if parse_version(xesmf.__version__) < parse_version("0.6.2"): + if Version(xesmf.__version__) < Version("0.6.2"): raise ImportError() except ImportError: xesmf = None diff --git a/tests/test_core_regrid.py b/tests/test_core_regrid.py new file mode 100644 index 00000000..ff23d6c9 --- /dev/null +++ b/tests/test_core_regrid.py @@ -0,0 
+1,1121 @@ +import os +from glob import glob +from pathlib import Path + +import cf_xarray # noqa +import numpy as np +import pytest +import xarray as xr +from packaging.version import Version +from roocs_grids import get_grid_file + +import clisops.utils.dataset_utils as clidu +from _common import ( + CMIP6_ATM_VERT_ONE_TIMESTEP, + CMIP6_GFDL_EXTENT, + CMIP6_OCE_HALO_CNRM, + CMIP6_STAGGERED_UCOMP, + CMIP6_TAS_ONE_TIME_STEP, + CMIP6_TAS_PRECISION_A, + CMIP6_TAS_PRECISION_B, + CMIP6_TOS_ONE_TIME_STEP, + CMIP6_UNSTR_ICON_A, + CORDEX_TAS_NO_BOUNDS, +) +from clisops import CONFIG +from clisops.core.regrid import ( + XESMF_MINIMUM_VERSION, + Grid, + Weights, + regrid, + weights_cache_flush, + weights_cache_init, +) +from clisops.ops.subset import subset +from clisops.utils.output_utils import FileLock + +try: + import xesmf + + if Version(xesmf.__version__) < Version(XESMF_MINIMUM_VERSION): + raise ImportError +except ImportError: + xesmf = None + + +# test from grid_id --predetermined +# test different types of grid e.g. unstructured, not supported type +# test for errors e.g. +# no lat/lon in dataset +# more than one latitude/longitude +# grid instructor tuple not correct length + + +XESMF_IMPORT_MSG = ( + f"xESMF >= {XESMF_MINIMUM_VERSION} is needed for regridding functionalities." +) + + +def test_grid_init_ds_tas_regular(load_esgf_test_data): + ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + grid = Grid(ds=ds) + + assert grid.format == "CF" + assert grid.source == "Dataset" + assert grid.lat == ds.lat.name + assert grid.lon == ds.lon.name + assert grid.type == "regular_lat_lon" + assert grid.extent == "global" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + assert grid.lat_bnds == ds.lat_bnds.name + assert grid.lon_bnds == ds.lon_bnds.name + assert grid.nlat == 80 + assert grid.nlon == 180 + assert grid.ncells == 14400 + + # not implemented yet + # assert self.mask + + +def test_grid_init_da_tas_regular(load_esgf_test_data): + ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + da = ds.tas + grid = Grid(ds=da) + + assert grid.format == "CF" + assert grid.source == "Dataset" + assert grid.lat == da.lat.name + assert grid.lon == da.lon.name + assert grid.type == "regular_lat_lon" + assert grid.extent == "global" + assert grid.contains_collapsed_cells is None + assert grid.contains_duplicated_cells is False + assert grid.lat_bnds is None + assert grid.lon_bnds is None + assert grid.nlat == 80 + assert grid.nlon == 180 + assert grid.ncells == 14400 + + +def test_grid_init_ds_tos_curvilinear(load_esgf_test_data): + ds = xr.open_dataset(CMIP6_TOS_ONE_TIME_STEP, use_cftime=True) + grid = Grid(ds=ds) + + assert grid.format == "CF" + assert grid.source == "Dataset" + assert grid.lat == ds.latitude.name + assert grid.lon == ds.longitude.name + assert grid.type == "curvilinear" + assert grid.extent == "global" + assert grid.lat_bnds == "vertices_latitude" + assert grid.lon_bnds == "vertices_longitude" + assert grid.contains_collapsed_cells + assert grid.contains_duplicated_cells + assert grid.nlat == 404 # 402 w/o halo # this is number of 'j's + assert grid.nlon == 802 # 800 w/o halo # this is the number of 'i's + assert grid.ncells == 324008 # 321600 w/o halo + + # not implemented yet + # assert self.mask + + +def test_grid_init_ds_tas_unstructured(load_esgf_test_data): + ds = xr.open_dataset(CMIP6_UNSTR_ICON_A, use_cftime=True) + grid = Grid(ds=ds) + + assert grid.format == "CF" + assert grid.source == "Dataset" + assert 
grid.lat == ds.latitude.name + assert grid.lon == ds.longitude.name + assert grid.type == "unstructured" + assert grid.extent == "global" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + assert grid.lat_bnds == "latitude_bnds" + assert grid.lon_bnds == "longitude_bnds" + assert grid.ncells == 20480 + print(grid.contains_collapsed_cells) + + # not implemented yet + # assert self.mask + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_instructor_global(): + grid_instructor = (1.5, 1.5) + grid = Grid(grid_instructor=grid_instructor) + + assert grid.format == "CF" + assert grid.source == "xESMF" + assert grid.lat == "lat" + assert grid.lon == "lon" + assert grid.type == "regular_lat_lon" + assert grid.extent == "global" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + + # check that grid_from_instructor sets the format to xESMF + grid._grid_from_instructor(grid_instructor) + assert grid.format == "xESMF" + + assert grid.lat_bnds == "lat_bnds" + assert grid.lon_bnds == "lon_bnds" + assert grid.nlat == 120 + assert grid.nlon == 240 + assert grid.ncells == 28800 + + # not implemented yet + # assert self.mask + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_instructor_2d_regional_change_lon(): + grid_instructor = (50, 240, 1.5, -90, 90, 1.5) + grid = Grid(grid_instructor=grid_instructor) + + assert grid.format == "CF" + assert grid.source == "xESMF" + assert grid.lat == "lat" + assert grid.lon == "lon" + assert grid.type == "regular_lat_lon" + assert grid.extent == "regional" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + + # check that grid_from_instructor sets the format to xESMF + grid._grid_from_instructor(grid_instructor) + assert grid.format == "xESMF" + + assert grid.lat_bnds == "lat_bnds" + assert grid.lon_bnds == "lon_bnds" + assert grid.nlat == 120 + assert grid.nlon == 127 + assert grid.ncells == 15240 + + # not implemented yet + # assert self.mask + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_instructor_2d_regional_change_lat(): + grid_instructor = (0, 360, 1.5, -60, 50, 1.5) + grid = Grid(grid_instructor=grid_instructor) + + assert grid.format == "CF" + assert grid.source == "xESMF" + assert grid.lat == "lat" + assert grid.lon == "lon" + assert grid.type == "regular_lat_lon" + + # Extent in y-direction ignored, as not of importance + # for xesmf.Regridder. 
Extent in x-direction should be + # detected as "global" + assert grid.extent == "global" + + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + + assert grid.lat_bnds == "lat_bnds" + assert grid.lon_bnds == "lon_bnds" + assert grid.nlat == 73 + assert grid.nlon == 240 + assert grid.ncells == 17520 + + # not implemented yet + # assert self.mask + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_instructor_2d_regional_change_lon_and_lat(): + grid_instructor = (50, 240, 1.5, -60, 50, 1.5) + grid = Grid(grid_instructor=grid_instructor) + + assert grid.format == "CF" + assert grid.source == "xESMF" + assert grid.lat == "lat" + assert grid.lon == "lon" + assert grid.type == "regular_lat_lon" + assert grid.extent == "regional" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + + # check that grid_from_instructor sets the format to xESMF + grid._grid_from_instructor(grid_instructor) + assert grid.format == "xESMF" + + assert grid.lat_bnds == "lat_bnds" + assert grid.lon_bnds == "lon_bnds" + assert grid.nlat == 73 + assert grid.nlon == 127 + assert grid.ncells == 9271 + + # not implemented yet + # assert self.mask + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_instructor_2d_global(): + grid_instructor = (0, 360, 1.5, -90, 90, 1.5) + grid = Grid(grid_instructor=grid_instructor) + + assert grid.format == "CF" + assert grid.source == "xESMF" + assert grid.lat == "lat" + assert grid.lon == "lon" + assert grid.type == "regular_lat_lon" + assert grid.extent == "global" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + + # check that grid_from_instructor sets the format to xESMF + grid._grid_from_instructor(grid_instructor) + assert grid.format == "xESMF" + + assert grid.lat_bnds == "lat_bnds" + assert grid.lon_bnds == "lon_bnds" + assert grid.nlat == 120 + assert grid.nlon == 240 + assert grid.ncells == 28800 + + # not implemented yet + # assert self.mask + + +def test_from_grid_id(): + "Test to create grid from grid_id" + grid = Grid(grid_id="ERA-40") + + assert grid.format == "CF" + assert grid.source == "Predefined_ERA-40" + assert grid.lat == "lat" + assert grid.lon == "lon" + assert grid.type == "regular_lat_lon" + assert grid.extent == "global" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + assert grid.lat_bnds == "lat_bnds" + assert grid.lon_bnds == "lon_bnds" + assert grid.nlat == 145 + assert grid.nlon == 288 + assert grid.ncells == 41760 + + # not implemented yet + # assert self.mask0 + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_from_ds_adaptive_extent(load_esgf_test_data): + "Test that the extent is evaluated as global for original and derived adaptive grid." 
+ dsA = xr.open_dataset(CMIP6_TOS_ONE_TIME_STEP, use_cftime=True) + dsB = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + dsC = xr.open_dataset(CMIP6_UNSTR_ICON_A, use_cftime=True) + + gA = Grid(ds=dsA) + gB = Grid(ds=dsB) + gC = Grid(ds=dsC) + gAa = Grid(ds=dsA, grid_id="adaptive") + gBa = Grid(ds=dsB, grid_id="adaptive") + gCa = Grid(ds=dsC, grid_id="auto") + + assert gA.extent == "global" + assert gB.extent == "global" + assert gC.extent == "global" + assert gAa.extent == "global" + assert gBa.extent == "global" + assert gCa.extent == "global" + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_grid_from_ds_adaptive_reproducibility(): + "Test that adaptive grids derived from the predefined grid files match the corresponding predefined grids." + fpathA = get_grid_file("0pt25deg") + dsA = xr.open_dataset(fpathA, use_cftime=True) + fpathB = get_grid_file("1deg") + dsB = xr.open_dataset(fpathB, use_cftime=True) + + gAa = Grid(ds=dsA, grid_id="adaptive") + gA = Grid(grid_id="0pt25deg") + print(repr(gAa)) + print(repr(gA)) + print(gAa.ds.lon[715:735]) + print(gA.ds.lon[715:735]) + + gBa = Grid(ds=dsB, grid_id="adaptive") + gB = Grid(grid_id="1deg") + print(gBa.ds.lon[170:190]) + print(gB.ds.lon[170:190]) + print(repr(gBa)) + print(repr(gB)) + + assert gA.extent == "global" + assert gA.compare_grid(gAa) + assert gB.extent == "global" + assert gB.compare_grid(gBa) + + +# @pytest.mark.xfail +def test_detect_extent_shifted_lon_frame(load_esgf_test_data): + "Test whether the extent can be correctly inferred for a dataset with shifted longitude frame." + # Load dataset with longitude ranging from (-300, 60) + ds = xr.open_dataset(CMIP6_GFDL_EXTENT, use_cftime=True) + + # Convert the longitude frame to 0,360 (shall happen implicitly in the future) + ds, ll, lu = clidu.cf_convert_between_lon_frames(ds, (0, 360)) + assert (ll, lu) == (0, 360) + + # Create Grid object and assert zonal extent + g = Grid(ds=ds) + assert g.extent == "global" + + +def test_compare_grid_same_resolution(): + "Test that two grids of same resolution from different sources evaluate as the same grid" + ds025 = xr.open_dataset(get_grid_file("0pt25deg_era5")) + g025 = Grid(grid_id="0pt25deg_era5", compute_bounds=True) + g025_lsm = Grid(grid_id="0pt25deg_era5_lsm", compute_bounds=True) + + assert g025.compare_grid(g025_lsm) + assert g025.compare_grid(ds025) + assert g025_lsm.compare_grid(ds025) + + +def test_compare_grid_diff_in_precision(load_esgf_test_data): + "Test that the same grid stored with different precision is evaluated as the same grid" + dsA = xr.open_dataset(CMIP6_TAS_PRECISION_A, use_cftime=True) + dsB = xr.open_dataset(CMIP6_TAS_PRECISION_B, use_cftime=True) + + gA = Grid(ds=dsA) + gB = Grid(ds=dsB) + + assert gA.compare_grid(gB) + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_compare_grid_hash_dict_and_verbose(capfd): + "Test Grid.hash_dict keys and Grid.compare_grid verbose option" + gA = Grid(grid_instructor=(1.0, 0.5)) + gB = Grid(grid_instructor=(1.0,)) + is_equal = gA.compare_grid(gB, verbose=True) + stdout, stderr = capfd.readouterr() + + assert stderr == "" + assert stdout == "The two grids differ in their respective lat, lat_bnds.\n" + assert not is_equal + assert len(gA.hash_dict) == 5 + assert list(gA.hash_dict.keys()) == ["lat", "lon", "lat_bnds", "lon_bnds", "mask"] + + +def test_to_netcdf(load_esgf_test_data, tmp_path): + "Test if grid file is properly written to disk using to_netcdf method." 
+ # Create Grid object + dsA = xr.open_dataset(CMIP6_TAS_PRECISION_A) + gA = Grid(ds=dsA) + + # Save to disk + outdir = Path(tmp_path, "grids") + outfile = "grid_test.nc" + gA.to_netcdf(folder=outdir, filename=outfile) + + # Read from disk - ensure outfile has been created and lockfile deleted + assert os.path.isfile(Path(outdir, outfile)) + assert len([os.path.basename(f) for f in glob(f"{outdir}/*")]) == 1 + dsB = xr.open_dataset(Path(outdir, outfile)) + gB = Grid(ds=dsB) + + # Ensure Grid attributes and ds attrs are the same + assert gA.compare_grid(gB) + assert gA.format == gB.format + assert gA.type == gB.type + assert gA.extent == gB.extent + assert gA.source == gB.source + assert gA.contains_collapsed_cells == gB.contains_collapsed_cells + assert sorted(list(gA.ds.attrs.keys()) + ["clisops"]) == sorted( + list(gB.ds.attrs.keys()) + ) + + # Ensure all variables have been deleted from the dataset + assert not list(gB.ds.data_vars) + assert sorted(list(gB.ds.coords)) == [gA.lat, gA.lat_bnds, gA.lon, gA.lon_bnds] + + # Ensure the non-CF-compliant attributes xarray commonly defines are not present: + assert "_FillValue" not in dsB[gB.lat_bnds].attrs.keys() + assert "_FillValue" not in dsB[gB.lon_bnds].attrs.keys() + assert "coordinates" not in dsB.attrs.keys() + + +def test_detect_collapsed_cells(load_esgf_test_data): + "Test that collapsed cells are properly identified" + dsA = xr.open_dataset(CMIP6_OCE_HALO_CNRM, use_cftime=True) + dsB = xr.open_dataset(CMIP6_TOS_ONE_TIME_STEP, use_cftime=True) + dsC = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + + gA = Grid(ds=dsA) + gB = Grid(ds=dsB) + gC = Grid(ds=dsC) + + assert gA.contains_collapsed_cells + assert gB.contains_collapsed_cells + assert not gC.contains_collapsed_cells + + +def test_detect_duplicated_cells(load_esgf_test_data): + "Test that collapsed cells are properly identified" + dsA = xr.open_dataset(CMIP6_OCE_HALO_CNRM, use_cftime=True) + dsB = xr.open_dataset(CMIP6_TOS_ONE_TIME_STEP, use_cftime=True) + dsC = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + + gA = Grid(ds=dsA) + gB = Grid(ds=dsB) + gC = Grid(ds=dsC) + + assert gA.contains_duplicated_cells + assert gB.contains_duplicated_cells + assert not gC.contains_duplicated_cells + + +def test_subsetted_grid(load_esgf_test_data): + ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + + area = (0.0, 10.0, 175.0, 90.0) + + ds = subset( + ds=ds, + area=area, + output_type="xarray", + )[0] + + grid = Grid(ds=ds) + + assert grid.format == "CF" + assert grid.source == "Dataset" + assert grid.lat == ds.lat.name + assert grid.lon == ds.lon.name + assert grid.type == "regular_lat_lon" + assert grid.extent == "regional" + assert not grid.contains_collapsed_cells + assert not grid.contains_duplicated_cells + + assert grid.lat_bnds == ds.lat_bnds.name + assert grid.lon_bnds == ds.lon_bnds.name + assert grid.nlat == 35 + assert grid.nlon == 88 + assert grid.ncells == 3080 + + # not implemented yet + # assert self.mask + + +def test_drop_vars_transfer_coords(load_esgf_test_data): + "Test for Grid methods drop_vars and transfer_coords" + ds = xr.open_dataset(CMIP6_ATM_VERT_ONE_TIMESTEP) + g = Grid(ds=ds) + gt = Grid(grid_id="0pt25deg_era5_lsm", compute_bounds=True) + assert sorted(list(g.ds.data_vars.keys())) == ["o3", "ps"] + assert list(gt.ds.data_vars.keys()) != [] + + gt._drop_vars() + assert gt.ds.attrs == {} + assert sorted(list(gt.ds.coords.keys())) == [ + "lat_bnds", + "latitude", + "lon_bnds", + "longitude", + ] + + gt._transfer_coords(g) + 
assert gt.ds.attrs["institution"] == "Max Planck Institute for Meteorology" + assert gt.ds.attrs["activity_id"] == "CMIP" + assert sorted(list(gt.ds.coords.keys())) == [ + "ap", + "ap_bnds", + "b", + "b_bnds", + "lat_bnds", + "latitude", + "lev", + "lev_bnds", + "lon_bnds", + "longitude", + "time", + "time_bnds", + ] + assert list(gt.ds.data_vars.keys()) == [] + + +def test_calculate_bounds_curvilinear(load_esgf_test_data): + "Test for bounds calculation for curvilinear grid" + ds = xr.open_dataset(CORDEX_TAS_NO_BOUNDS).isel( + {"rlat": range(10), "rlon": range(10)} + ) + g = Grid(ds=ds, compute_bounds=True) + assert g.lat_bnds is not None + assert g.lon_bnds is not None + + +def test_calculate_bounds_duplicated_cells(load_esgf_test_data): + "Test for bounds calculation for curvilinear grid" + ds = xr.open_dataset(CORDEX_TAS_NO_BOUNDS).isel( + {"rlat": range(10), "rlon": range(10)} + ) + + # create duplicated cells + ds["lat"][:, 0] = ds["lat"][:, 1] + ds["lon"][:, 0] = ds["lon"][:, 1] + + # assert raised exception + with pytest.raises( + Exception, + match="This grid contains duplicated cell centers. Therefore bounds cannot be computed.", + ): + Grid(ds=ds, compute_bounds=True) + + +def test_centers_within_bounds_curvilinear(load_esgf_test_data): + "Test for bounds calculation for curvilinear grid" + ds = xr.open_dataset(CORDEX_TAS_NO_BOUNDS).isel( + {"rlat": range(10), "rlon": range(10)} + ) + g = Grid(ds=ds, compute_bounds=True) + assert g.lat_bnds is not None + assert g.lon_bnds is not None + assert g.contains_collapsed_cells is False + + # Check that there are bounds values smaller and greater than the cell center values + ones = np.ones((g.nlat, g.nlon), dtype=int) + assert np.all( + ones + == xr.where( + np.sum(xr.where(g.ds[g.lat] >= g.ds[g.lat_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + assert np.all( + ones + == xr.where( + np.sum(xr.where(g.ds[g.lat] <= g.ds[g.lat_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + assert np.all( + ones + == xr.where( + np.sum(xr.where(g.ds[g.lon] >= g.ds[g.lon_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + assert np.all( + ones + == xr.where( + np.sum(xr.where(g.ds[g.lon] <= g.ds[g.lon_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + + +def test_centers_within_bounds_regular_lat_lon(): + "Test for bounds calculation of regular lat lon grid" + g = Grid(grid_id="0pt25deg_era5_lsm", compute_bounds=True) + assert g.lat_bnds is not None + assert g.lon_bnds is not None + assert bool(g.contains_collapsed_cells) is False + + # Check that there are bounds values smaller and greater than the cell center values + ones_lat = np.ones((g.nlat,), dtype=int) + ones_lon = np.ones((g.nlon,), dtype=int) + assert np.all( + ones_lat + == xr.where( + np.sum(xr.where(g.ds[g.lat] >= g.ds[g.lat_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + assert np.all( + ones_lat + == xr.where( + np.sum(xr.where(g.ds[g.lat] <= g.ds[g.lat_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + assert np.all( + ones_lon + == xr.where( + np.sum(xr.where(g.ds[g.lon] >= g.ds[g.lon_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + assert np.all( + ones_lon + == xr.where( + np.sum(xr.where(g.ds[g.lon] <= g.ds[g.lon_bnds], 1, 0), -1) > 0, 1, 0 + ) + ) + + +def test_data_vars_coords_reset_and_cfxr(load_esgf_test_data): + dsA = xr.open_dataset(CMIP6_ATM_VERT_ONE_TIMESTEP) + + # generate dummy areacella + areacella = xr.DataArray( + { + "dims": ("lat", "lon"), + "attrs": { + "standard_name": "cell_area", + "cell_methods": "area: sum", + }, + "data": np.ones(18432, dtype=np.float32).reshape((96, 192)), + } + ) + dsA.update({"areacella": areacella}) + dsB = 
xr.decode_cf(dsA, decode_coords="all") + + # Grid._set_data_vars_and_coords should (re)set coords appropriately + gA = Grid(ds=dsA) + gB = Grid(ds=dsB) + + # cf_xarray should be able to identify important attributes and present both datasets equally + assert gA.compare_grid(gB) + assert gA.ds.cf.cell_measures == gB.ds.cf.cell_measures + assert gA.ds.o3.cf.cell_measures == gB.ds.o3.cf.cell_measures + assert gA.ds.cf.formula_terms == gB.ds.cf.formula_terms + assert gA.ds.o3.cf.formula_terms == gB.ds.o3.cf.formula_terms + assert gA.ds.cf.bounds == gB.ds.cf.bounds + assert str(gA.ds.cf) == str(gB.ds.cf) + + +# test all methods +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +class TestWeights: + def test_grids_in_and_out_bilinear(self, tmp_path): + ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + grid_in = Grid(ds=ds) + + assert grid_in.extent == "global" + + grid_instructor_out = (0, 360, 1.5, -90, 90, 1.5) + grid_out = Grid(grid_instructor=grid_instructor_out) + + assert grid_out.extent == "global" + + weights_cache_init(Path(tmp_path, "weights")) + w = Weights(grid_in=grid_in, grid_out=grid_out, method="bilinear") + + assert w.method == "bilinear" + assert ( + w.id + == "8edb4ee828dbebc2dc8e193281114093_bf73249f1725126ad3577727f3652019_peri_no-degen_bilinear" + ) + assert w.periodic + assert w.id in w.filename + assert "xESMF_v" in w.tool + assert w.format == "xESMF" + + # default file_name = method_inputgrid_outputgrid_periodic" + assert w.regridder.filename == "bilinear_80x180_120x240_peri.nc" + + def test_grids_in_and_out_conservative(self, tmp_path): + ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + grid_in = Grid(ds=ds) + + assert grid_in.extent == "global" + + grid_instructor_out = (0, 360, 1.5, -90, 90, 1.5) + grid_out = Grid(grid_instructor=grid_instructor_out) + + assert grid_out.extent == "global" + + weights_cache_init(Path(tmp_path, "weights")) + w = Weights(grid_in=grid_in, grid_out=grid_out, method="conservative") + + assert w.method == "conservative" + assert ( + w.id + == "8edb4ee828dbebc2dc8e193281114093_bf73249f1725126ad3577727f3652019_peri_no-degen_conservative" + ) + assert ( + w.periodic != w.regridder.periodic + ) # xESMF resets periodic to False for conservative weights + assert w.id in w.filename + assert "xESMF_v" in w.tool + assert w.format == "xESMF" + + # default file_name = method_inputgrid_outputgrid_periodic" + assert w.regridder.filename == "conservative_80x180_120x240.nc" + + def test_from_id(self): + "Test creating a Weights object by reading weights from disk, identified by the id." + pass + + def test_from_disk(self): + "Test creating a Weights object by reading an xESMF or other weights file from disk." + pass + + def test_conservative_no_bnds(self, load_esgf_test_data, tmp_path): + "Test whether exception is raised when no bounds present for conservative remapping." 
+ ds = xr.open_dataset(CORDEX_TAS_NO_BOUNDS) + gi = Grid(ds=ds) + go = Grid(grid_id="1deg", compute_bounds=True) + + assert gi.lat_bnds is None + assert gi.lon_bnds is None + assert go.lat_bnds is not None + assert go.lon_bnds is not None + + with pytest.raises( + Exception, + match="For conservative remapping, horizontal grid bounds have to be defined for the source and target grids.", + ): + weights_cache_init(Path(tmp_path, "weights")) + Weights(grid_in=gi, grid_out=go, method="conservative") + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_Weights_compute(tmp_path, load_esgf_test_data): + "Test the generation of Weights with the _compute method." + g = Grid(grid_id="1deg") + g_out = Grid(grid_id="2deg_lsm") + + weights_cache_init(Path(tmp_path, "weights")) + + # Exception should be raised if input and output grid are evaluated as equal + with pytest.raises( + Exception, + match="The selected source and target grids are the same. No regridding operation required.", + ): + Weights( + g, + Grid(grid_instructor=(0.0, 360.0, 1.0, -90.0, 90.0, 1.0)), + method="bilinear", + ) + + # Exception should be raised for conservative method if input or output grid do not contain bounds + with pytest.raises( + Exception, + match="For conservative remapping, horizontal grid bounds have to be defined for the source and target grids.", + ): + Weights(g, g_out, method="conservative") + + # Test computation and cache storage + g_out = Grid(grid_id="2deg_lsm", compute_bounds=True) + w = Weights(g, g_out, method="nearest_s2d") + assert w.id == w._read_info_from_cache("uid") + assert w.tool == w._read_info_from_cache("tool") + assert w.regridder.periodic == w.periodic + assert w._read_info_from_cache("method") == "nearest_s2d" + assert w.regridder.method == w.method + assert w.format == "xESMF" + assert w.regridder.filename == "nearest_s2d_180x360_90x180_peri.nc" + assert not w.regridder.reuse_weights + assert w.regridder.ignore_degenerate is None + assert w.regridder.n_in == 180 * 360 + assert w.regridder.n_out == 90 * 180 + assert w.ignore_degenerate is None + assert w.filename == "weights_" + w.id + ".nc" + + # Test weight reusage + z = Weights(g, g_out, method="nearest_s2d") + assert z.regridder.reuse_weights + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_Weights_compute_unstructured(tmp_path, load_esgf_test_data): + "Test the generation of Weights for unstructured grids with the _compute method." + ds = xr.open_dataset(CMIP6_UNSTR_ICON_A, use_cftime=True) + g = Grid(ds=ds) + g_out = Grid(grid_id="2deg_lsm", compute_bounds=True) + + weights_cache_init(Path(tmp_path, "weights")) + + # Exception should be raised for other than nearest_s2d remapping method + with pytest.raises( + Exception, + match="For unstructured grids, the only supported remapping method that is currently supported is nearest neighbour.", + ): + Weights(g, g_out, method="conservative") + + # Check translated xesmf settings + w = Weights(g, g_out, method="nearest_s2d") + assert w.regridder.sequence_in + assert not w.regridder.sequence_out + assert w.regridder.ignore_degenerate is None + assert w.regridder.n_in == g.ncells + assert w.regridder.n_out == 90 * 180 + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_Weights_generate_id(tmp_path, load_esgf_test_data): + "Test the generation of Weight ids." 
+ g = Grid(grid_id="1deg") + g_out = Grid(grid_id="2pt5deg") + + weights_cache_init(Path(tmp_path, "weights")) + w = Weights(g, g_out, method="bilinear") + + assert w.id == w._generate_id() + assert w.id == "_".join([g.hash, g_out.hash, "peri", "no-degen", "bilinear"]) + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_Weights_init_with_collapsed_cells(tmp_path, load_esgf_test_data): + "Test the creation of remapping weights for a grid containing collapsed cells" + # ValueError: ESMC_FieldRegridStore failed with rc = 506. Please check the log files (named "*ESMF_LogFile"). + ds = xr.open_dataset(CMIP6_OCE_HALO_CNRM, use_cftime=True) + + g = Grid(ds=ds) + g_out = Grid(grid_instructor=(10.0,)) + + assert g.contains_collapsed_cells + + weights_cache_init(Path(tmp_path, "weights")) + Weights(g, g_out, method="bilinear") + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_Regridder_filename(tmp_path): + """Test that Regridder filename is reset properly.""" + g1 = Grid(grid_id="2pt5deg") + g2 = Grid(grid_id="2deg_lsm") + + weights_cache_init(Path(tmp_path, "weights")) + + w = Weights(g1, g2, method="nearest_s2d") + + assert w.regridder.filename == w.regridder._get_default_filename() + assert os.path.isfile(Path(tmp_path, "weights", w.filename)) + assert w.filename != w.regridder.filename + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_cache_init_and_flush(tmp_path): + "Test of the cache init and flush functionalities" + + weights_dir = Path(tmp_path, "clisops_weights") + weights_cache_init(weights_dir) + + gi = Grid(grid_instructor=20) + go = Grid(grid_instructor=10) + Weights(grid_in=gi, grid_out=go, method="nearest_s2d") + + flist = sorted(os.path.basename(f) for f in glob(f"{weights_dir}/*")) + flist_ref = [ + "grid_4d324aecaa8302ab0f2f212e9821b00f.nc", + "grid_96395935b4e81f2a5b55970bd920d82c.nc", + "weights_4d324aecaa8302ab0f2f212e9821b00f_96395935b4e81f2a5b55970bd920d82c_peri_no-degen_nearest_s2d.json", + "weights_4d324aecaa8302ab0f2f212e9821b00f_96395935b4e81f2a5b55970bd920d82c_peri_no-degen_nearest_s2d.nc", + ] + assert flist == flist_ref + + weights_cache_flush() + assert glob(f"{weights_dir}/*") == [] + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_cache_lock_mechanism(load_esgf_test_data, tmp_path): + """Test lock mechanism of local regrid weights cache.""" + ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True) + + grid_in = Grid(ds=ds) + grid_out = Grid(grid_instructor=10) + + # First round - creating the weights should work without problems + weights_cache_init(Path(tmp_path, "weights")) + w = Weights(grid_in=grid_in, grid_out=grid_out, method="nearest_s2d") + + # Remove grid files to suppress related warnings of already existing files + os.remove(Path(tmp_path, "weights", "grid_" + grid_in.hash + ".nc")) + os.remove(Path(tmp_path, "weights", "grid_" + grid_out.hash + ".nc")) + + # Second round, but manually put lockfile in place + LOCK_FILE = Path(tmp_path, "weights", w.filename + ".lock") + lock = FileLock(LOCK_FILE) + lock.acquire(timeout=10) + + # Fail test if lockfile is not recognized + with pytest.warns(UserWarning, match="lockfile") as issuedWarnings: + Weights(grid_in=grid_in, grid_out=grid_out, method="nearest_s2d") + if not issuedWarnings: + raise Exception("Lockfile not recognized/ignored.") + else: + assert len(issuedWarnings) == 1 + # for issuedWarning in issuedWarnings: + # print(issuedWarning.message) + + lock.release() + + +@pytest.mark.skipif(xesmf is 
+
+
+@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG)
+def test_cache_lock_mechanism(load_esgf_test_data, tmp_path):
+    """Test the lock mechanism of the local regrid weights cache."""
+    ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True)
+
+    grid_in = Grid(ds=ds)
+    grid_out = Grid(grid_instructor=10)
+
+    # First round - creating the weights should work without problems
+    weights_cache_init(Path(tmp_path, "weights"))
+    w = Weights(grid_in=grid_in, grid_out=grid_out, method="nearest_s2d")
+
+    # Remove the grid files to suppress warnings about already existing files
+    os.remove(Path(tmp_path, "weights", "grid_" + grid_in.hash + ".nc"))
+    os.remove(Path(tmp_path, "weights", "grid_" + grid_out.hash + ".nc"))
+
+    # Second round, but manually put a lockfile in place
+    LOCK_FILE = Path(tmp_path, "weights", w.filename + ".lock")
+    lock = FileLock(LOCK_FILE)
+    lock.acquire(timeout=10)
+
+    # Fail the test if the lockfile is not recognized
+    with pytest.warns(UserWarning, match="lockfile") as issuedWarnings:
+        Weights(grid_in=grid_in, grid_out=grid_out, method="nearest_s2d")
+    if not issuedWarnings:
+        raise Exception("Lockfile not recognized/ignored.")
+    else:
+        assert len(issuedWarnings) == 1
+
+    lock.release()
+
+
+@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG)
+def test_cache_reinit_and_write_protection(tmp_path):
+    """Test that the Regridder does not write to the cache if a lockfile exists."""
+    g1 = Grid(grid_id="2pt5deg")
+    g2 = Grid(grid_id="2deg_lsm")
+
+    orig_cache_dir = CONFIG["clisops:grid_weights"]["local_weights_dir"]
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    # Create the weights, get the filename and flush the cache
+    w = Weights(g1, g2, method="nearest_s2d")
+    fname = w.filename
+    weights_cache_flush()
+
+    # Create a lockfile
+    LOCK_FILE = Path(tmp_path, "weights", fname + ".lock")
+    lock = FileLock(LOCK_FILE)
+    lock.acquire(timeout=10)
+
+    # Recreate the weights
+    w = Weights(g1, g2, method="nearest_s2d")
+    lock.release()
+
+    # Ensure that the cache contains neither the weight file nor its metadata
+    flist = sorted(os.path.basename(f) for f in glob(f"{Path(tmp_path, 'weights')}/*"))
+    assert all(f.startswith("grid_") for f in flist)
+    assert len(flist) == 2
+    assert Path(CONFIG["clisops:grid_weights"]["local_weights_dir"]) == Path(
+        tmp_path, "weights"
+    )
+    assert CONFIG["clisops:grid_weights"]["local_weights_dir"] != orig_cache_dir
+
+
+@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG)
+def test_read_metadata(tmp_path):
+    """Test the Weights method _read_info_from_cache."""
+    g1 = Grid(grid_instructor=10.0)
+    g2 = Grid(grid_instructor=15.0)
+
+    # Create the weights and assert the attributes written to the cache
+    weights_cache_init(Path(tmp_path, "weights"))
+    w = Weights(g1, g2, method="nearest_s2d")
+
+    assert w._read_info_from_cache("filename") == w.filename
+    assert w._read_info_from_cache("method") == "nearest_s2d"
+    assert w._read_info_from_cache("source_uid") == g1.hash
+    assert w._read_info_from_cache("target_extent") == g2.extent
+    assert w._read_info_from_cache("bla") is None
+
+
+@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG)
+class TestRegrid:
+    def _setup(self):
+        if hasattr(self, "setup_done"):
+            return
+
+        self.ds = xr.open_dataset(CMIP6_TAS_ONE_TIME_STEP, use_cftime=True)
+        self.grid_in = Grid(ds=self.ds)
+
+        self.grid_instructor_out = (0, 360, 1.5, -90, 90, 1.5)
+        self.grid_out = Grid(grid_instructor=self.grid_instructor_out)
+        self.setup_done = True
+
+    def test_adaptive_masking(self, load_esgf_test_data, tmp_path):
+        self._setup()
+        weights_cache_init(Path(tmp_path, "weights"))
+        w = Weights(grid_in=self.grid_in, grid_out=self.grid_out, method="conservative")
+        r = regrid(self.grid_in, self.grid_out, w, adaptive_masking_threshold=0.7)
+        print(r)
+
+    def test_no_adaptive_masking(self, load_esgf_test_data, tmp_path):
+        self._setup()
+        weights_cache_init(Path(tmp_path, "weights"))
+        w = Weights(grid_in=self.grid_in, grid_out=self.grid_out, method="bilinear")
+        r = regrid(self.grid_in, self.grid_out, w, adaptive_masking_threshold=-1.0)
+        print(r)
+
+    def test_duplicated_cells_warning_issued(self, load_esgf_test_data, tmp_path):
+        self._setup()
+        weights_cache_init(Path(tmp_path, "weights"))
+        w = Weights(grid_in=self.grid_in, grid_out=self.grid_out, method="conservative")
+
+        # Cheat regrid into thinking grid_in contains duplicated cells
+        self.grid_in.contains_duplicated_cells = True
+
+        with pytest.warns(
+            UserWarning,
+            match="The grid of the selected dataset contains duplicated cells. "
+            "For the conservative remapping method, "
+            "duplicated grid cells contribute to the resulting value, "
+            "which is in most parts counter-acted by the applied re-normalization. "
+            "However, please be wary with the results and consider removing / masking "
+            "the duplicated cells before remapping.",
+        ) as issuedWarnings:
+            r = regrid(self.grid_in, self.grid_out, w, adaptive_masking_threshold=0.0)
+        if not issuedWarnings:
+            raise Exception(
+                "No warning issued regarding the duplicated cells in the grid."
+            )
+        else:
+            assert len(issuedWarnings) == 1
+        print(r)
+
+    def test_regrid_dataarray(self, load_esgf_test_data, tmp_path):
+        self._setup()
+        weights_cache_init(Path(tmp_path, "weights"))
+        w = Weights(grid_in=self.grid_in, grid_out=self.grid_out, method="nearest_s2d")
+        grid_da = Grid(self.grid_in.ds.tas)
+
+        vattrs = (
+            "regrid_method",
+            "standard_name",
+            "long_name",
+            "comment",
+            "units",
+            "cell_methods",
+            "cell_measures",
+            "history",
+        )
+        gattrs = (
+            "grid",
+            "grid_label",
+            "regrid_operation",
+            "regrid_tool",
+            "regrid_weights_uid",
+        )
+
+        r1 = regrid(grid_da, self.grid_out, w, keep_attrs=True)
+        assert vattrs == tuple(r1["tas"].attrs.keys())
+        assert gattrs == tuple(r1.attrs.keys())
+
+        r2 = regrid(grid_da, self.grid_out, w, keep_attrs=False)
+        assert ("regrid_method",) == tuple(r2["tas"].attrs.keys())
+        assert gattrs == tuple(r2.attrs.keys())
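+
+
+# Illustrative sketch, not part of the original test suite: the
+# adaptive_masking_threshold values used in TestRegrid above. A negative
+# value disables adaptive masking, while values in [0, 1] renormalize the
+# remapped field (an interpretation drawn from the tests and comments in
+# this file).
+def _example_adaptive_masking(grid_in, grid_out, weights):
+    # No adaptive masking: plain application of the remapping weights.
+    r_plain = regrid(grid_in, grid_out, weights, adaptive_masking_threshold=-1.0)
+    # Adaptive masking: renormalize cells with sufficient valid source area.
+    r_masked = regrid(grid_in, grid_out, weights, adaptive_masking_threshold=0.7)
+    return r_plain, r_masked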
" + "However, please be wary with the results and consider removing / masking " + "the duplicated cells before remapping.", + ) as issuedWarnings: + r = regrid(self.grid_in, self.grid_out, w, adaptive_masking_threshold=0.0) + if not issuedWarnings: + raise Exception( + "No warning issued regarding the duplicated cells in the grid." + ) + else: + assert len(issuedWarnings) == 1 + print(r) + + def test_regrid_dataarray(self, load_esgf_test_data, tmp_path): + self._setup() + weights_cache_init(Path(tmp_path, "weights")) + w = Weights(grid_in=self.grid_in, grid_out=self.grid_out, method="nearest_s2d") + grid_da = Grid(self.grid_in.ds.tas) + + vattrs = ( + "regrid_method", + "standard_name", + "long_name", + "comment", + "units", + "cell_methods", + "cell_measures", + "history", + ) + gattrs = ( + "grid", + "grid_label", + "regrid_operation", + "regrid_tool", + "regrid_weights_uid", + ) + + r1 = regrid(grid_da, self.grid_out, w, keep_attrs=True) + assert vattrs == tuple(r1["tas"].attrs.keys()) + assert gattrs == tuple(r1.attrs.keys()) + + r2 = regrid(grid_da, self.grid_out, w, keep_attrs=False) + assert ("regrid_method",) == tuple(r2["tas"].attrs.keys()) + assert gattrs == tuple(r2.attrs.keys()) + + +@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG) +def test_duplicated_cells_renormalization(load_esgf_test_data, tmp_path): + # todo: Should probably be an xesmf test as well, will do PR there in the future + ds = xr.open_dataset(CMIP6_STAGGERED_UCOMP, use_cftime=True) + + # some internal xesmf code to create array of ones + missing = np.isnan(ds.tauuo) + ds["tauuo"] = (~missing).astype("d") + + grid_in = Grid(ds=ds) + assert grid_in.contains_collapsed_cells is True + assert grid_in.contains_duplicated_cells is True + + # Make sure all values that are not missing, are equal to one + assert grid_in.ncells == ds["tauuo"].where(~missing, 1.0).sum() + # Make sure all values that are missing are equal to 0 + assert 0.0 == ds["tauuo"].where(missing, 0.0).sum() + + grid_out = Grid(grid_instructor=(0, 360, 1.5, -90, 90, 1.5)) + weights_cache_init(Path(tmp_path, "weights")) + w = Weights(grid_in=grid_in, grid_out=grid_out, method="conservative") + + # Remap using adaptive masking + r1 = regrid(grid_in, grid_out, w, adaptive_masking_threshold=0.5) + + # Remap using default setting (na_thres = 0.5) + r2 = regrid(grid_in, grid_out, w) + + # Make sure both options yield equal results + xr.testing.assert_equal(r1, r2) + + # Remap without using adaptive masking - internally, then adaptive masking is used + # with threshold 0., to still renormalize contributions from duplicated cells + # but not from masked cells or out-of-source-domain area + r3 = regrid(grid_in, grid_out, w, adaptive_masking_threshold=-1.0) + + # Make sure, contributions from duplicated cells (i.e. 
diff --git a/tests/test_dataset_utils.py b/tests/test_dataset_utils.py
index 5e1dab75..441b2980 100644
--- a/tests/test_dataset_utils.py
+++ b/tests/test_dataset_utils.py
@@ -3,6 +3,8 @@
 import packaging.version
 import pytest
 import xarray as xr
+from packaging.version import Version
+from roocs_grids import get_grid_file
 
 import clisops.utils.dataset_utils as clidu
 from _common import (
@@ -13,10 +15,24 @@
     CMIP6_OCE_HALO_CNRM,
     CMIP6_SICONC,
     CMIP6_TAS_ONE_TIME_STEP,
+    CMIP6_TAS_PRECISION_A,
     CMIP6_TOS_ONE_TIME_STEP,
     CMIP6_UNSTR_ICON_A,
     CORDEX_TAS_ONE_TIMESTEP,
 )
+from clisops.core.regrid import XESMF_MINIMUM_VERSION
+
+try:
+    import xesmf
+
+    if Version(xesmf.__version__) < Version(XESMF_MINIMUM_VERSION):
+        raise ImportError()
+except ImportError:
+    xesmf = None
+
+XESMF_IMPORT_MSG = (
+    f"xESMF >= {XESMF_MINIMUM_VERSION} is needed for regridding functionalities."
+)
 
 
 def test_add_day():
@@ -60,6 +76,147 @@
     )
 
 
+def test_add_hor_CF_coord_attrs():
+    """Test the function adding standard attributes to horizontal coordinate variables."""
+    # Create a basic dataset
+    ds = xr.Dataset(
+        data_vars={},
+        coords={
+            "lat": (["lat"], np.ones(1)),
+            "lon": (["lon"], np.ones(1)),
+            "lat_bnds": (["lat", "bnds"], np.ones((1, 2))),
+            "lon_bnds": (["lon", "bnds"], np.ones((1, 2))),
+        },
+    )
+
+    # Ensure the attributes have been added
+    ds = clidu.add_hor_CF_coord_attrs(ds=ds)
+    assert ds["lat"].attrs["bounds"] == "lat_bnds"
+    assert ds["lon"].attrs["bounds"] == "lon_bnds"
+    assert ds["lat"].attrs["units"] == "degrees_north"
+    assert ds["lon"].attrs["units"] == "degrees_east"
+    assert ds["lat"].attrs["axis"] == "Y"
+    assert ds["lon"].attrs["axis"] == "X"
+    assert ds["lat"].attrs["standard_name"] == "latitude"
+    assert ds["lon"].attrs["standard_name"] == "longitude"
+
+    # Ensure the attributes have been updated (and conflicting ones overwritten)
+    ds["lat"].attrs["bounds"] = "lat_b"
+    ds["lon"].attrs["standard_name"] = "lon"
+    ds["lat_bnds"].attrs["custom"] = "custom"
+    ds = clidu.add_hor_CF_coord_attrs(ds=ds, keep_attrs=True)
+    assert ds["lat"].attrs["bounds"] == "lat_bnds"
+    assert ds["lon"].attrs["bounds"] == "lon_bnds"
+    assert ds["lat"].attrs["units"] == "degrees_north"
+    assert ds["lon"].attrs["units"] == "degrees_east"
+    assert ds["lat"].attrs["axis"] == "Y"
+    assert ds["lon"].attrs["axis"] == "X"
+    assert ds["lat"].attrs["standard_name"] == "latitude"
+    assert ds["lon"].attrs["standard_name"] == "longitude"
+    assert ds["lat_bnds"].attrs["custom"] == "custom"
+
+    # An incorrect coordinate variable name should lead to a KeyError
+    with pytest.raises(KeyError) as exc:
+        ds = clidu.add_hor_CF_coord_attrs(ds, lat="latitude")
+    assert (
+        str(exc.value)
+        == "'Not all specified coordinate variables exist in the dataset.'"
+    )
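+
+
+# Illustrative sketch, not part of the original test suite: the reformat and
+# detect helpers tested below tell the grid formats apart by their
+# characteristic coordinate names. This mapping is a simplification distilled
+# from the assertions in the tests that follow.
+_GRID_FORMAT_MARKERS = {
+    "CF": ("lat", "lon", "lat_bnds", "lon_bnds"),
+    "SCRIP": ("grid_center_lat", "grid_center_lon", "grid_corner_lat", "grid_corner_lon"),
+    "xESMF": ("lat", "lon", "lat_b", "lon_b"),
+}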
+
+
+@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG)
+def test_reformat_xESMF_to_CF():
+    """Test the reformat operation reformat_xESMF_to_CF."""
+    # Use the xesmf utility function to create a dataset with a global grid
+    ds = xesmf.util.grid_global(5.0, 5.0)
+
+    # It should have certain variables defined
+    assert all(coord in ds for coord in ["lat", "lon", "lat_b", "lon_b"])
+    assert all(dim in ds.dims for dim in ["x", "y", "x_b", "y_b"])
+
+    # Reformat
+    ds.attrs["xesmf"] = xesmf.__version__
+    ds_ref = clidu.reformat_xESMF_to_CF(ds=ds, keep_attrs=True)
+    assert all(coord in ds_ref for coord in ["lat", "lon", "lat_bnds", "lon_bnds"])
+    assert all(dim in ds_ref.dims for dim in ["lat", "lon", "bnds"])
+    assert ds_ref.attrs["xesmf"] == xesmf.__version__
+
+
+def test_reformat_SCRIP_to_CF():
+    """Test the reformat operation reformat_SCRIP_to_CF."""
+    # Load a dataset in SCRIP format (using roocs_grids)
+    ds = xr.open_dataset(get_grid_file("2pt5deg"))
+
+    # It should have certain variables defined
+    assert all(
+        coord in ds
+        for coord in [
+            "grid_center_lat",
+            "grid_center_lon",
+            "grid_corner_lat",
+            "grid_corner_lon",
+            "grid_dims",
+            "grid_area",
+            "grid_imask",
+        ]
+    )
+    assert all(dim in ds.dims for dim in ["grid_corners", "grid_size", "grid_rank"])
+
+    # Reformat
+    ds_ref = clidu.reformat_SCRIP_to_CF(ds=ds, keep_attrs=True)
+    assert all(coord in ds_ref for coord in ["lat", "lon", "lat_bnds", "lon_bnds"])
+    assert all(dim in ds_ref.dims for dim in ["lat", "lon", "bnds"])
+    assert ds_ref.attrs["Conventions"] == "SCRIP"
+
+
+def test_detect_shape_regular():
+    """Test the detect_shape function for a regular grid."""
+    # Load the dataset
+    ds = xr.open_dataset(get_grid_file("0pt25deg_era5_lsm"))
+
+    # Detect the shape
+    nlat, nlon, ncells = clidu.detect_shape(
+        ds, lat="latitude", lon="longitude", grid_type="regular_lat_lon"
+    )
+
+    # Assertions
+    assert nlat == 721
+    assert nlon == 1440
+    assert ncells == nlat * nlon
+
+
+def test_detect_shape_unstructured():
+    """Test the detect_shape function for an unstructured grid."""
+    # Load the dataset
+    ds = xr.open_dataset(CMIP6_UNSTR_ICON_A, use_cftime=True)
+
+    # Detect the shape
+    nlat, nlon, ncells = clidu.detect_shape(
+        ds, lat="latitude", lon="longitude", grid_type="unstructured"
+    )
+
+    # Assertions
+    assert nlat == ncells
+    assert nlon == ncells
+    assert ncells == 20480
+
+
+@pytest.mark.skipif(xesmf is None, reason=XESMF_IMPORT_MSG)
+def test_detect_format():
+    """Test the detect_format function."""
+    # Load/create datasets in SCRIP, CF and xESMF format
+    ds_cf = xr.open_dataset(get_grid_file("0pt25deg_era5_lsm"))
+    ds_scrip = xr.open_dataset(get_grid_file("0pt25deg_era5"))
+    ds_xesmf = xesmf.util.grid_global(5.0, 5.0)
+
+    # Assertions
+    assert clidu.detect_format(ds_cf) == "CF"
+    assert clidu.detect_format(ds_scrip) == "SCRIP"
+    assert clidu.detect_format(ds_xesmf) == "xESMF"
+
+
 def test_detect_coordinate_and_bounds():
     "Test detect_bounds and detect_coordinate functions."
 ds_a = xr.open_mfdataset(C3S_CORDEX_AFR_TAS, use_cftime=True, combine="by_coords")
@@ -310,3 +467,53 @@
 # todo: add a few more tests of cf_convert_lon_frame using xe.util functions to create regional and global datasets
+
+
+def test_calculate_bounds_curvilinear(load_esgf_test_data):
+    """Test bounds calculation for a curvilinear grid."""
+    # Load a CORDEX dataset (curvilinear grid)
+    ds = xr.open_dataset(CORDEX_TAS_ONE_TIMESTEP).isel(
+        {"rlat": range(100, 120), "rlon": range(100, 120)}
+    )
+
+    # Drop the bounds variables and recompute them
+    ds_nb = ds.drop_vars(["lon_vertices", "lat_vertices"])
+    ds_nb = clidu.generate_bounds_curvilinear(ds_nb, lat="lat", lon="lon")
+
+    # Sort the vertex values of every cell
+    for i in range(1, 19):
+        for j in range(1, 19):
+            ds.lat_vertices[i, j, :] = np.sort(ds.lat_vertices.values[i, j, :])
+            ds.lon_vertices[i, j, :] = np.sort(ds.lon_vertices.values[i, j, :])
+            ds_nb.lat_bnds[i, j, :] = np.sort(ds_nb.lat_bnds.values[i, j, :])
+            ds_nb.lon_bnds[i, j, :] = np.sort(ds_nb.lon_bnds.values[i, j, :])
+
+    # Assert all values are close (discarding cells at the edge of the selected grid area)
+    xr.testing.assert_allclose(
+        ds.lat_vertices.isel({"rlat": range(1, 19), "rlon": range(1, 19)}),
+        ds_nb.lat_bnds.isel({"rlat": range(1, 19), "rlon": range(1, 19)}),
+        rtol=1e-06,
+        atol=0,
+    )
+    xr.testing.assert_allclose(
+        ds.lon_vertices.isel({"rlat": range(1, 19), "rlon": range(1, 19)}),
+        ds_nb.lon_bnds.isel({"rlat": range(1, 19), "rlon": range(1, 19)}),
+        rtol=1e-06,
+        atol=0,
+    )
+
+
+def test_calculate_bounds_rectilinear(load_esgf_test_data):
+    """Test bounds calculation for a rectilinear grid."""
+    # Load a CMIP6 dataset (rectilinear grid)
+    ds = xr.open_dataset(CMIP6_TAS_PRECISION_A)
+
+    # Drop the bounds variables and recompute them
+    ds_nb = ds.drop_vars(["lon_bnds", "lat_bnds"])
+    ds_nb = clidu.generate_bounds_rectilinear(ds_nb, lat="lat", lon="lon")
+
+    # Assert all values are close
+    xr.testing.assert_allclose(ds.lat_bnds, ds_nb.lat_bnds, rtol=1e-06, atol=0)
+    xr.testing.assert_allclose(ds.lon_bnds, ds_nb.lon_bnds, rtol=1e-06, atol=0)
diff --git a/tests/test_ops_regrid.py b/tests/test_ops_regrid.py
new file mode 100644
index 00000000..e57ad8f9
--- /dev/null
+++ b/tests/test_ops_regrid.py
@@ -0,0 +1,281 @@
+import os
+from pathlib import Path
+
+import cf_xarray  # noqa
+import pytest
+import xarray as xr
+from roocs_grids import get_grid_file, grid_dict
+
+from _common import (
+    CMIP5_MRSOS_ONE_TIME_STEP,
+    CMIP6_ATM_VERT_ONE_TIMESTEP,
+    CMIP6_OCE_HALO_CNRM,
+    CMIP6_TOS_ONE_TIME_STEP,
+)
+from clisops.core.regrid import XESMF_MINIMUM_VERSION, weights_cache_init, xe
+from clisops.ops.regrid import regrid
+
+XESMF_IMPORT_MSG = (
+    f"xESMF >= {XESMF_MINIMUM_VERSION} is needed for regridding functionalities."
+)
+
+
+def _check_output_nc(result, fname="output_001.nc"):
+    assert fname in [os.path.basename(_) for _ in result]
+
+
+def _load_ds(fpath):
+    return xr.open_mfdataset(fpath)
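+
+
+# Illustrative sketch, not part of the original test suite: the ops-level
+# regrid() call pattern exercised by the tests below. "1deg" is a roocs_grids
+# identifier; input_nc and out_dir are hypothetical arguments.
+def _example_ops_regrid(input_nc, out_dir):
+    return regrid(
+        input_nc,  # file path or xarray object to remap
+        method="nearest_s2d",  # remapping method
+        adaptive_masking_threshold=0.5,  # renormalization threshold
+        grid="1deg",  # target grid: id, instructor, or dataset
+        output_dir=out_dir,
+        output_type="netcdf",
+        file_namer="standard",
+    )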
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_basic(tmpdir, load_esgf_test_data, tmp_path):
+    """Test a basic regridding operation."""
+    fpath = CMIP5_MRSOS_ONE_TIME_STEP
+    basename = os.path.splitext(os.path.basename(fpath))[0]
+    method = "nearest_s2d"
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    result = regrid(
+        fpath,
+        method=method,
+        adaptive_masking_threshold=0.5,
+        grid="1deg",
+        output_dir=tmpdir,
+        output_type="netcdf",
+        file_namer="standard",
+    )
+
+    _check_output_nc(
+        result, fname=f"{basename}-20051201_regrid-{method}-180x360_cells_grid.nc"
+    )
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_grid_as_none(tmpdir, load_esgf_test_data, tmp_path):
+    """
+    Test behaviour when None is passed as the grid:
+    an exception should be raised since no target grid can be determined.
+    """
+    fpath = CMIP5_MRSOS_ONE_TIME_STEP
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    with pytest.raises(
+        Exception,
+        match="xarray.Dataset, grid_id or grid_instructor have to be specified as input.",
+    ):
+        regrid(
+            fpath,
+            grid=None,
+            output_dir=tmpdir,
+            output_type="netcdf",
+            file_namer="standard",
+        )
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+@pytest.mark.parametrize("grid_id", sorted(grid_dict))
+def test_regrid_regular_grid_to_all_roocs_grids(
+    tmpdir, load_esgf_test_data, grid_id, tmp_path
+):
+    """Test regridding of a regular lat/lon field to all roocs grid types."""
+    fpath = CMIP5_MRSOS_ONE_TIME_STEP
+    basename = os.path.splitext(os.path.basename(fpath))[0]
+    method = "nearest_s2d"
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    result = regrid(
+        fpath,
+        method=method,
+        adaptive_masking_threshold=0.5,
+        grid=grid_id,
+        output_dir=tmpdir,
+        output_type="netcdf",
+        file_namer="standard",
+    )
+
+    nc_file = result[0]
+    assert os.path.basename(nc_file).startswith(f"{basename}-20051201_regrid-{method}-")
+
+    # Check that the output file can be read
+    ds = xr.open_dataset(nc_file)
+    assert "mrsos" in ds
+    assert ds.mrsos.size > 100
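+
+
+# Illustrative note, drawn from the assertions in the next test (an
+# interpretation, not documented behaviour): keep_attrs steers attribute
+# handling - omitted/True keeps the input attrs, False drops them, and
+# "target" favours the attrs of the target grid.
+def _example_keep_attrs_modes(ds):
+    kwargs = dict(
+        method="nearest_s2d",
+        adaptive_masking_threshold=-1,
+        grid="2deg_lsm",
+        output_type="xarray",
+    )
+    kept = regrid(ds, **kwargs)  # input attrs preserved
+    dropped = regrid(ds, keep_attrs=False, **kwargs)  # attrs removed
+    target = regrid(ds, keep_attrs="target", **kwargs)  # target grid attrs kept
+    return kept, dropped, target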
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_keep_attrs(load_esgf_test_data, tmp_path):
+    """Test if dataset and variable attributes are kept / removed as specified."""
+    fpath = CMIP6_TOS_ONE_TIME_STEP
+    method = "nearest_s2d"
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    ds = xr.open_dataset(fpath).isel(time=0)
+
+    # Regrid - preserving the input attrs
+    result = regrid(
+        ds,
+        method=method,
+        adaptive_masking_threshold=-1,
+        grid="2deg_lsm",
+        output_type="xarray",
+    )
+
+    # Regrid - dropping the attrs
+    result_na = regrid(
+        ds,
+        method=method,
+        adaptive_masking_threshold=-1,
+        grid="2deg_lsm",
+        output_type="xarray",
+        keep_attrs=False,
+    )
+
+    # Regrid - keeping the target attrs
+    result_ta = regrid(
+        ds,
+        method=method,
+        adaptive_masking_threshold=-1,
+        grid="2deg_lsm",
+        output_type="xarray",
+        keep_attrs="target",
+    )
+
+    ds_remap = result[0]
+    ds_remap_na = result_na[0]
+    ds_remap_ta = result_ta[0]
+
+    assert "tos" in ds_remap and "tos" in ds_remap_na and "tos" in ds_remap_ta
+    assert all(key in ds_remap.tos.attrs.keys() for key in ds.tos.attrs.keys())
+    assert all(
+        key in ds_remap.attrs.keys()
+        for key in ds.attrs.keys()
+        if key not in ["nominal_resolution"]
+    )
+    # todo: remove the restriction when nominal_resolution of the target grid is calculated in core/regrid.py
+    assert all(key not in ds_remap_na.tos.attrs.keys() for key in ds.tos.attrs.keys())
+    assert all(
+        key not in ds_remap_na.attrs.keys()
+        for key in ds.attrs.keys()
+        if key not in ["grid", "grid_label"]
+    )
+    assert all(key in ds_remap_ta.tos.attrs.keys() for key in ds.tos.attrs.keys())
+    assert all(
+        key not in ds_remap_ta.attrs.keys()
+        for key in ds.attrs.keys()
+        if key
+        not in ["source", "Conventions", "history", "NCO", "grid", "grid_label"]
+    )
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_halo_simple(load_esgf_test_data, tmp_path):
+    """Test regridding with a simple halo."""
+    fpath = CMIP6_TOS_ONE_TIME_STEP
+    ds = xr.open_dataset(fpath).isel(time=0)
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    ds_out = regrid(
+        ds,
+        method="conservative",
+        adaptive_masking_threshold=-1,
+        grid=5,
+        output_type="xarray",
+    )[0]
+
+    # Expected if the halo were removed:
+    # assert ds_out.attrs["regrid_operation"] == "conservative_402x800_36x72"
+    # Expected with the halo present:
+    assert ds_out.attrs["regrid_operation"] == "conservative_404x802_36x72"
+
+
+@pytest.mark.xfail
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_halo_adv(load_esgf_test_data, tmp_path):
+    """Test regridding of a dataset with a more complex halo."""
+    fpath = CMIP6_OCE_HALO_CNRM
+    ds = xr.open_dataset(fpath).isel(time=0)
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    ds_out = regrid(
+        ds,
+        method="conservative",
+        adaptive_masking_threshold=-1,
+        grid=5,
+        output_type="xarray",
+    )[0]
+
+    # Once the halo can be properly removed (possibly yielding 1049x1440), the test can be updated
+    assert ds_out.attrs["regrid_operation"] == "conservative_1050x1442_36x72"
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_same_grid_exception(tmpdir, tmp_path):
+    """Test that an exception is raised when source and target grid are the same."""
+    fpath = get_grid_file("0pt25deg_era5")
+
+    weights_cache_init(Path(tmp_path, "weights"))
+
+    with pytest.raises(
+        Exception,
+        match="The selected source and target grids are the same. No regridding operation required.",
+    ):
+        regrid(
+            fpath,
+            method="conservative",
+            adaptive_masking_threshold=0.5,
+            grid="0pt25deg_era5_lsm",
+            output_dir=tmpdir,
+            output_type="netcdf",
+            file_namer="standard",
+        )
+
+
+@pytest.mark.skipif(xe is None, reason=XESMF_IMPORT_MSG)
+def test_regrid_cmip6_nc_consistent_bounds_and_coords(load_esgf_test_data, tmpdir):
+    """Test the clisops regrid function and check the metadata added by xarray."""
+    result = regrid(
+        ds=CMIP6_ATM_VERT_ONE_TIMESTEP,
+        method="nearest_s2d",
+        grid=10.0,
+        output_dir=tmpdir,
+        output_type="nc",
+        file_namer="standard",
+    )
+    res = _load_ds(result)
+    # Check that no fill value is set in the bounds encodings
+    assert "_FillValue" not in res.lat_bnds.encoding
+    assert "_FillValue" not in res.lon_bnds.encoding
+    assert "_FillValue" not in res.time_bnds.encoding
+    assert "_FillValue" not in res.lev_bnds.encoding
+    assert "_FillValue" not in res.ap_bnds.encoding
+    assert "_FillValue" not in res.b_bnds.encoding
+    # Check that no fill value is set in the coordinate encodings
+    assert "_FillValue" not in res.time.encoding
+    assert "_FillValue" not in res.lat.encoding
+    assert "_FillValue" not in res.lon.encoding
+    assert "_FillValue" not in res.lev.encoding
+    assert "_FillValue" not in res.ap.encoding
+    assert "_FillValue" not in res.b.encoding
+    # Check that "coordinates" is not set in the bounds encodings
+    assert "coordinates" not in res.lat_bnds.encoding
+    assert "coordinates" not in res.lon_bnds.encoding
+    assert "coordinates" not in res.time_bnds.encoding
+    assert "coordinates" not in res.lev_bnds.encoding
+    assert "coordinates" not in res.ap_bnds.encoding
+    assert "coordinates" not in res.b_bnds.encoding
+    # Check that "coordinates" is not set in the variable encodings
+    assert "coordinates" not in res.o3.encoding
+    assert "coordinates" not in res.ps.encoding
diff --git a/tests/test_ops_xarray_mean.py b/tests/test_ops_xarray_mean.py
index e71ceab7..16710101 100644
--- a/tests/test_ops_xarray_mean.py
+++ b/tests/test_ops_xarray_mean.py
@@ -119,7 +119,7 @@
         CMIP5_TAS,
         combine="by_coords",
         use_cftime=True,
-    )
+    ).load()
 
     ds_tas_mean = ds.tas.mean(dim="time", keep_attrs=False)
     ds_mean = ds.mean(dim="time", keep_attrs=False)
diff --git a/tests/test_output_utils.py b/tests/test_output_utils.py
index 6fc605dc..7c9d921a 100644
--- a/tests/test_output_utils.py
+++ b/tests/test_output_utils.py
@@ -1,6 +1,7 @@
 import os
 import sys
 import tempfile
+import time
 from pathlib import Path
 
 import xarray as xr
@@ -10,6 +11,7 @@
 from clisops.utils.common import expand_wildcards
 from clisops.utils.file_namers import get_file_namer
 from clisops.utils.output_utils import (
+    FileLock,
     get_chunk_length,
     get_da,
     get_output,
@@ -271,3 +273,36 @@
     assert chunked_ds1.chunks == chunked_ds2.chunks
     # test that ds = ds.unify_chunks hasn't changed ds.chunks
     assert chunked_ds2.chunks == chunked_ds2_unified.chunks
+
+
+def test_filelock_simple(tmp_path):
+    LOCK_FILE = Path(tmp_path, "test.lock")
+    DATA_FILE = Path(tmp_path, "test.dat")
+
+    lock = FileLock(LOCK_FILE)
+
+    lock.acquire()
+    try:
+        time.sleep(1)
+        assert os.path.isfile(LOCK_FILE)
+        assert lock.state == "LOCKED"
+        with open(DATA_FILE, "a") as f:
+            f.write("1")
+    finally:
+        lock.release()
+
+    time.sleep(1)
+    assert not os.path.isfile(LOCK_FILE)
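+
+
+# Illustrative sketch, not part of the original test suite: the
+# acquire/release pattern from test_filelock_simple above, kept
+# exception-safe with try/finally. The lock and data paths are arbitrary
+# example arguments.
+def _example_filelock_usage(lock_path, data_path):
+    lock = FileLock(lock_path)
+    lock.acquire(timeout=10)  # raises if the lock cannot be obtained in time
+    try:
+        # Critical section: only one process at a time writes the file.
+        with open(data_path, "a") as f:
+            f.write("1")
+    finally:
+        lock.release()  # always remove the lockfile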
+
+
+def test_filelock_already_locked(tmp_path):
+    LOCK_FILE = Path(tmp_path, "test.lock")
+
+    lock1 = FileLock(LOCK_FILE)
+    lock2 = FileLock(LOCK_FILE)
+
+    lock1.acquire(timeout=10)
+
+    # The second acquire must time out while the first lock is held
+    with pytest.raises(Exception) as exc:
+        lock2.acquire(timeout=5)
+    assert str(exc.value) == f"Could not obtain file lock on {LOCK_FILE}"
+
+    lock1.release()
diff --git a/tox.ini b/tox.ini
index 5a3d716d..26bef059 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,23 +2,23 @@
 min_version = 4.0
 envlist =
     py{38,39,310,311}
-    black
+    lint
     docs
 requires =
-    pip >= 21.0
+    pip >= 23.0
 opts = -v

-[testenv:black]
+[testenv:lint]
 skip_install = True
 basepython = python
 deps =
-    flake8
     black
+    flake8
 commands_pre =
     pip list
 commands =
-    flake8 clisops tests
     black --check clisops tests --exclude tests/mini-esgf-data
+    flake8 clisops tests

 [testenv:docs]
 extras = docs