Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load esmvalcore.dataset.Dataset objects in parallel using Dask #2517

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
27 changes: 6 additions & 21 deletions esmvalcore/cmor/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional
Expand Down Expand Up @@ -137,7 +136,7 @@ def fix_metadata(
Returns
-------
iris.cube.CubeList
Fixed cubes.
A list containing a single fixed cube.

"""
# Update extra_facets with variable information given as regular arguments
Expand All @@ -161,27 +160,13 @@ def fix_metadata(
session=session,
frequency=frequency,
)
fixed_cubes = CubeList()

# Group cubes by input file and apply all fixes to each group element
# (i.e., each file) individually
by_file = defaultdict(list)
for cube in cubes:
by_file[cube.attributes.get("source_file", "")].append(cube)
for fix in fixes:
cubes = fix.fix_metadata(cubes)

for cube_list in by_file.values():
cube_list = CubeList(cube_list)
for fix in fixes:
cube_list = fix.fix_metadata(cube_list)

# The final fix is always GenericFix, whose fix_metadata method always
# returns a single cube
cube = cube_list[0]

cube.attributes.pop("source_file", None)
fixed_cubes.append(cube)

return fixed_cubes
# The final fix is always GenericFix, whose fix_metadata method always
# returns a single cube
return CubeList(cubes[:1])


def fix_data(
Expand Down
103 changes: 88 additions & 15 deletions esmvalcore/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import re
import textwrap
import uuid
from collections.abc import Iterable
from copy import deepcopy
from fnmatch import fnmatchcase
from itertools import groupby
from pathlib import Path
from typing import Any, Iterator, Sequence, Union
from typing import Any, Iterator, Sequence, TypeVar, Union

import dask
from dask.delayed import Delayed
from iris.cube import Cube

from esmvalcore import esgf, local
Expand Down Expand Up @@ -84,6 +87,14 @@ def _ismatch(facet_value: FacetValue, pattern: FacetValue) -> bool:
)


T = TypeVar("T")


def _first(elems: Iterable[T]) -> T:
"""Return the first element."""
return next(iter(elems))


class Dataset:
"""Define datasets, find the related files, and load them.

Expand Down Expand Up @@ -693,9 +704,19 @@ def files(self) -> Sequence[File]:
def files(self, value):
self._files = value

def load(self) -> Cube:
def load(self, compute: bool = True) -> Cube | Delayed:
"""Load dataset.

Parameters
----------
compute:
If :obj:`True`, return the :class:`~iris.cube.Cube` immediately.
If :obj:`False`, return a :class:`~dask.delayed.Delayed` object
that can be used to load the cube by calling its
:meth:`~dask.delayed.Delayed.compute` method. Multiple datasets
can be loaded in parallel by passing a list of such delayeds
to :func:`dask.compute`.

Raises
------
InputFilesNotFound
Expand All @@ -718,7 +739,7 @@ def load(self) -> Cube:
supplementary_cubes.append(supplementary_cube)

output_file = _get_output_file(self.facets, self.session.preproc_dir)
cubes = preprocess(
cubes = dask.delayed(preprocess)(
[cube],
"add_supplementary_variables",
input_files=input_files,
Expand All @@ -727,7 +748,10 @@ def load(self) -> Cube:
supplementary_cubes=supplementary_cubes,
)

return cubes[0]
cube = dask.delayed(_first)(cubes)
if compute:
return cube.compute()
schlunma marked this conversation as resolved.
Show resolved Hide resolved
return cube

def _load(self) -> Cube:
"""Load self.files into an iris cube and return it."""
Expand All @@ -742,7 +766,16 @@ def _load(self) -> Cube:
msg = "\n".join(lines)
raise InputFilesNotFound(msg)

input_files = [
file.local_file(self.session["download_dir"])
if isinstance(file, esgf.ESGFFile)
else file
for file in self.files
]
output_file = _get_output_file(self.facets, self.session.preproc_dir)
debug = self.session["save_intermediary_cubes"]

# Load all input files and concatenate them.
fix_dir_prefix = Path(
self.session._fixed_file_dir,
self._get_joined_summary_facets("_", join_lists=True) + "_",
Expand All @@ -765,6 +798,51 @@ def _load(self) -> Cube:
**self.facets,
}
settings["concatenate"] = {"check_level": self.session["check_level"]}

result = []
for input_file in input_files:
Copy link
Contributor

@schlunma schlunma Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes how data is passed through the different preprocessor functions, doesn't it?

Right now, for example, fix_metadata will get ALL cubes from ALL files as input. With this change here, it will only get the cubes from one file, right?

I know that fix_metadata itself groups by file, but this is already very problematic (see #1806 and #2551).

I also fear that this might have other undesired side effects. Why do you need to treat these first preprocessor functions differently in the new code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to treat these first preprocessor functions differently in the new code?

To improve parallelism. Like this, each input file can be loaded and preprocessed up to the concatenate step in parallel.

This changes how data is passed through the different preprocessor functions, doesn't it?

No, it just takes the grouping out of fix_metadata and implements it in the function calling fix_metadata to enable additional parallelism. If this pull request is merged, #2551 would need to be updated to do the grouping here instead of inside fix_metadata.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I think I misunderstood the code in the first place. The function preprocess is not at all straightforward when it comes to handling of input and output types...I agree that the behavior has not changed.

I will test this with a couple of recipes once Levante is running again next week. In the meantime, would it make sense to remove the grouping of files in fix_metadata? It would be confusing to have this in two places of the code. I know that this wouldn't be strictly backwards-compatible, but the grouping was only enabled if the cubes have a source_file attribute (which is probably only the case when used within ESMValTool). I highly doubt that this function would be very useful outside of ESMValTool anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the grouping in d5a39af, but where would you suggest we remove the "source_file" attribute now? Apart from grouping, it is also used to generate error messages from the cmor checkers. Should it be removed after cmor_check_data?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I would either remove it after cmor_check_data or remove it altogether from the code. The preprocessors log all filenames anyway now, so its not as important anymore as it used to be.

files = dask.delayed(preprocess)(
[input_file],
"fix_file",
input_files=[input_file],
output_file=output_file,
debug=debug,
**settings["fix_file"],
)
# Multiple cubes may be present in a file.
cubes = dask.delayed(preprocess)(
files,
"load",
input_files=[input_file],
output_file=output_file,
debug=debug,
**settings["load"],
)
# Combine the cubes into a single cube per file.
cubes = dask.delayed(preprocess, pure=False)(
cubes,
"fix_metadata",
input_files=[input_file],
output_file=output_file,
debug=debug,
**settings["fix_metadata"],
)
cube = dask.delayed(_first)(cubes)
result.append(cube)

# Concatenate the cubes from all files.
result = dask.delayed(preprocess, pure=False)(
result,
"concatenate",
input_files=input_files,
output_file=output_file,
debug=debug,
**settings["concatenate"],
)

# At this point `result` is a list containing a single cube. Apply the
# remaining preprocessor functions to this cube.
settings.clear()
settings["cmor_check_metadata"] = {
"check_level": self.session["check_level"],
"cmor_table": self.facets["project"],
Expand All @@ -777,6 +855,7 @@ def _load(self) -> Cube:
"timerange": self.facets["timerange"],
}
settings["fix_data"] = {
"pure": False,
"session": self.session,
**self.facets,
}
Expand All @@ -787,24 +866,18 @@ def _load(self) -> Cube:
"frequency": self.facets["frequency"],
"short_name": self.facets["short_name"],
}

result = [
file.local_file(self.session["download_dir"])
if isinstance(file, esgf.ESGFFile)
else file
for file in self.files
]
for step, kwargs in settings.items():
result = preprocess(
pure = settings.pop("pure", True)
result = dask.delayed(preprocess, pure=pure)(
result,
step,
input_files=self.files,
input_files=input_files,
output_file=output_file,
debug=self.session["save_intermediary_cubes"],
debug=debug,
**kwargs,
)

cube = result[0]
cube = dask.delayed(_first)(result)
return cube

def from_ranges(self) -> list["Dataset"]:
Expand Down
12 changes: 9 additions & 3 deletions tests/integration/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import iris.coords
import iris.cube
import pytest
from dask.delayed import Delayed

from esmvalcore.config import CFG
from esmvalcore.dataset import Dataset
Expand Down Expand Up @@ -55,7 +56,8 @@ def example_data(tmp_path, monkeypatch):
monkeypatch.setitem(CFG, "output_dir", tmp_path / "output_dir")


def test_load(example_data):
@pytest.mark.parametrize("lazy", [True, False])
def test_load(example_data, lazy):
tas = Dataset(
short_name="tas",
mip="Amon",
Expand All @@ -72,7 +74,11 @@ def test_load(example_data):
tas.find_files()
print(tas.files)

cube = tas.load()

if lazy:
result = tas.load(compute=False)
assert isinstance(result, Delayed)
cube = result.compute()
else:
cube = tas.load()
assert isinstance(cube, iris.cube.Cube)
assert cube.cell_measures()