diff --git a/CHANGELOG.md b/CHANGELOG.md index cbb8ea1..c30bb81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- add ability to derive variables from input datasets [\#34](https://github.com/mllam/mllam-data-prep/pull/34), @ealerskans - add github PR template to guide development process on github [\#44](https://github.com/mllam/mllam-data-prep/pull/44), @leifdenby ## [v0.5.0](https://github.com/mllam/mllam-data-prep/releases/tag/v0.5.0) diff --git a/README.md b/README.md index 5f5fcdf..034aa60 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ The package can also be used as a python module to create datasets directly, for import mllam_data_prep as mdp config_path = "example.danra.yaml" -config = mdp.Config.from_yaml_file(config_path) +config = mdp.Config.load_config(config_path) ds = mdp.create_dataset(config=config) ``` @@ -175,6 +175,18 @@ inputs: variables: # use surface incoming shortwave radiation as forcing - swavr0m + derived_variables: + # derive variables to be used as forcings + toa_radiation: + kwargs: + time: time + lat: lat + lon: lon + function: mllam_data_prep.ops.derived_variables.calculate_toa_radiation + hour_of_day: + kwargs: + time: time + function: mllam_data_prep.ops.derived_variables.calculate_hour_of_day dim_mapping: time: method: rename @@ -286,15 +298,26 @@ inputs: grid_index: method: stack dims: [x, y] - target_architecture_variable: state + target_output_variable: state danra_surface: path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr dims: [time, x, y] variables: - # shouldn't really be using sea-surface pressure as "forcing", but don't - # have radiation varibles in danra yet - - pres_seasurface + # use surface incoming shortwave radiation as forcing + - swavr0m + derived_variables: + # derive variables to be used as forcings + toa_radiation: + kwargs: + time: time + lat: lat + lon: lon + function: mllam_data_prep.derived_variables.calculate_toa_radiation + hour_of_day: + kwargs: + time: time + function: mllam_data_prep.derived_variables.calculate_hour_of_day dim_mapping: time: method: rename @@ -305,7 +328,7 @@ inputs: forcing_feature: method: stack_variables_by_var_name name_format: "{var_name}" - target_architecture_variable: forcing + target_output_variable: forcing ... ``` @@ -315,11 +338,44 @@ The `inputs` section defines the source datasets to extract data from. Each sour - `path`: the path to the source dataset. This can be a local path or a URL to e.g. a zarr dataset or netCDF file, anything that can be read by `xarray.open_dataset(...)`. - `dims`: the dimensions that the source dataset is expected to have. This is used to check that the source dataset has the expected dimensions and also makes it clearer in the config file what the dimensions of the source dataset are. - `variables`: selects which variables to extract from the source dataset. This may either be a list of variable names, or a dictionary where each key is the variable name and the value defines a dictionary of coordinates to do selection on. When doing selection you may also optionally define the units of the variable to check that the units of the variable match the units of the variable in the model architecture. -- `target_architecture_variable`: the variable in the model architecture that the source dataset should be mapped to. 
+- `target_output_variable`: the variable in the model architecture that the source dataset should be mapped to.
 - `dim_mapping`: defines how the dimensions of the source dataset should be mapped to the dimensions of the model architecture. This is done by defining a method to apply to each dimension. The methods are:
   - `rename`: simply rename the dimension to the new name
   - `stack`: stack the listed dimension to create the dimension in the output
   - `stack_variables_by_var_name`: stack the dimension into the new dimension, and also stack the variable name into the new variable name. This is useful when you have multiple variables with the same dimensions that you want to stack into a single variable.
+- `derived_variables`: defines the variables to derive from the variables available in the source dataset. This should be a dictionary where each key is the name of the variable to be derived and the value is a dictionary with the following entries. See the 'Derived Variables' section below for more details.
+  - `function`: the function used to derive the variable. This should be a string and may either be the full namespace of the function (e.g. `mllam_data_prep.ops.derived_variables.calculate_toa_radiation`) or, if the function is included in the `mllam_data_prep.ops.derived_variables` module, just the function name.
+  - `kwargs`: the arguments to the function used to derive the variable. This is a dictionary where each key is the name of a variable to select from the source dataset and each value is the name of the argument of `function` that the variable is passed as.
+
+#### Derived Variables
+Variables that are not part of the source dataset but can be derived from variables in the source dataset can also be included. They should be defined in their own section, called `derived_variables`, as illustrated in the example config above and in the `example.danra.yaml` config file.
+
+To derive a variable, the function used to derive it (`function`) and the arguments to that function (`kwargs`) need to be specified, as explained above. In addition, an optional section called `attrs` can be added, in which attributes can be set on the derived variable, as illustrated below.
+```yaml
+  derived_variables:
+    toa_radiation:
+      kwargs:
+        time: time
+        lat: lat
+        lon: lon
+      function: mllam_data_prep.ops.derived_variables.calculate_toa_radiation
+      attrs:
+        units: W*m**-2
+        long_name: top-of-atmosphere incoming radiation
+```
+
+Note that the attributes `units` and `long_name` are required. This means that if the function used to derive a variable does not set these attributes, they **must** be set in the config file. If using a function defined in `mllam_data_prep.ops.derived_variables`, the `attrs` section is optional since the attributes are already set by the function. In this case, adding the `units` and `long_name` attributes to the `attrs` section of the derived variable in the config file will overwrite the attributes already set by the function.
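+
+Functions outside of `mllam-data-prep` can also be used, as long as they are referenced by their full namespace and accept the variables listed under `kwargs` as arguments. As an illustration, a minimal sketch of what such a user-defined function could look like (the function, module and variable names here are hypothetical and not part of `mllam-data-prep`):
+
+```python
+import numpy as np
+import xarray as xr
+
+
+def calculate_wind_speed(u: xr.DataArray, v: xr.DataArray) -> xr.DataArray:
+    """Derive wind speed from the zonal and meridional wind components."""
+    wind_speed = np.sqrt(u**2 + v**2)
+    wind_speed.name = "wind_speed"
+    # `units` and `long_name` are required attributes for derived variables, so
+    # either set them here or supply them in the `attrs` section of the config
+    wind_speed.attrs["units"] = "m/s"
+    wind_speed.attrs["long_name"] = "wind speed"
+    return wind_speed
+```
+
+Such a function could then be referenced in the config as e.g. `function: my_package.derived.calculate_wind_speed`, with `kwargs` mapping the variable names in the source dataset to the `u` and `v` arguments.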
+ +Currently, the following derived variables are included as part of `mllam-data-prep`: +- `toa_radiation`: + - Top-of-atmosphere incoming radiation + - function: `mllam_data_prep.ops.derived_variables.calculate_toa_radiation` +- `hour_of_day`: + - Hour of day (cyclically encoded) + - function: `mllam_data_prep.ops.derived_variables.calculate_hour_of_day` +- `day_of_year`: + - Day of year (cyclically encoded) + - function: `mllam_data_prep.ops.derived_variables.calculate_day_of_year` ### Config schema versioning diff --git a/example.danra.yaml b/example.danra.yaml index 3edf126..30682ff 100644 --- a/example.danra.yaml +++ b/example.danra.yaml @@ -61,6 +61,18 @@ inputs: variables: # use surface incoming shortwave radiation as forcing - swavr0m + derived_variables: + # derive variables to be used as forcings + toa_radiation: + kwargs: + time: time + lat: lat + lon: lon + function: mllam_data_prep.ops.derived_variables.calculate_toa_radiation + hour_of_day: + kwargs: + time: time + function: mllam_data_prep.ops.derived_variables.calculate_hour_of_day dim_mapping: time: method: rename diff --git a/mllam_data_prep/config.py b/mllam_data_prep/config.py index 14e80ef..93f407b 100644 --- a/mllam_data_prep/config.py +++ b/mllam_data_prep/config.py @@ -52,6 +52,28 @@ class ValueSelection: units: str = None +@dataclass +class DerivedVariable: + """ + Defines a derived variables, where the kwargs (variables required + for the calculation) and the function (for calculating the variable) + are specified. Optionally, in case a function does not return an + `xr.DataArray` with the required attributes (`units` and `long_name`) set, + these should be specified in `attrs`, e.g. + {"attrs": "units": "W*m**-2, "long_name": "top-of-the-atmosphere radiation"}. + Additional attributes can also be set if desired. + + Attributes: + kwargs: Variables required for calculating the derived variable. + function: Function used to calculate the derived variable. + attrs: Attributes (e.g. `units` and `long_name`) to set for the derived variable. + """ + + kwargs: Dict[str, str] + function: str + attrs: Optional[Dict[str, str]] = field(default_factory=dict) + + @dataclass class DimMapping: """ @@ -120,7 +142,8 @@ class InputDataset: 1) the path to the dataset, 2) the expected dimensions of the dataset, 3) the variables to select from the dataset (and optionally subsection - along the coordinates for each variable) and finally + along the coordinates for each variable) or the variables to derive + from the dataset, and finally 4) the method by which the dimensions and variables of the dataset are mapped to one of the output variables (this includes stacking of all the selected variables into a new single variable along a new coordinate, @@ -134,11 +157,6 @@ class InputDataset: dims: List[str] List of the expected dimensions of the dataset. E.g. `["time", "x", "y"]`. These will be checked to ensure consistency of the dataset being read. - variables: Union[List[str], Dict[str, Dict[str, ValueSelection]]] - List of the variables to select from the dataset. E.g. `["temperature", "precipitation"]` - or a dictionary where the keys are the variable names and the values are dictionaries - defining the selection for each variable. E.g. `{"temperature": levels: {"values": [1000, 950, 900]}}` - would select the "temperature" variable and only the levels 1000, 950, and 900. dim_mapping: Dict[str, DimMapping] Mapping of the variables and dimensions in the input dataset to the dimensions of the output variable (`target_output_variable`). 
The key is the name of the output dimension to map to @@ -151,14 +169,23 @@ class InputDataset: (e.g. two datasets that coincide in space and time will only differ in the feature dimension, so the two will be combined by concatenating along the feature dimension). If a single shared coordinate cannot be found then an exception will be raised. + variables: Union[List[str], Dict[str, Dict[str, ValueSelection]]] + List of the variables to select from the dataset. E.g. `["temperature", "precipitation"]` + or a dictionary where the keys are the variable names and the values are dictionaries + defining the selection for each variable. E.g. `{"temperature": levels: {"values": [1000, 950, 900]}}` + would select the "temperature" variable and only the levels 1000, 950, and 900. + derived_variables: Dict[str, DerivedVariable] + Dictionary of variables to derive from the dataset, where the keys are the names variables will be given and + the values are `DerivedVariable` definitions that specify how to derive a variable. """ path: str dims: List[str] - variables: Union[List[str], Dict[str, Dict[str, ValueSelection]]] dim_mapping: Dict[str, DimMapping] target_output_variable: str - attributes: Dict[str, Any] = None + variables: Optional[Union[List[str], Dict[str, Dict[str, ValueSelection]]]] = None + derived_variables: Optional[Dict[str, DerivedVariable]] = None + attributes: Optional[Dict[str, Any]] = field(default_factory=dict) @dataclass @@ -258,7 +285,7 @@ class Output: variables: Dict[str, List[str]] coord_ranges: Dict[str, Range] = None - chunking: Dict[str, int] = None + chunking: Dict[str, int] = field(default_factory=dict) splitting: Splitting = None @@ -301,6 +328,54 @@ class Config(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard): class _(JSONWizard.Meta): raise_on_unknown_json_key = True + @staticmethod + def load_config(*args, **kwargs): + """ + Wrapper function for `from_yaml_file` to load config file and validate that: + - either `variables` or `derived_variables` are present in the config + - if both `variables` and `derived_variables` are present, that they don't + add the same variables to the dataset + + Parameters + ---------- + *args: Positional arguments for `from_yaml_file` + **kwargs: Keyword arguments for `from_yaml_file` + + Returns + ------- + config: Config + """ + + # Load the config + config = Config.from_yaml_file(*args, **kwargs) + + for input_dataset in config.inputs.values(): + if not input_dataset.variables and not input_dataset.derived_variables: + raise InvalidConfigException( + "At least one of the keys `variables` and `derived_variables` must be included" + " in the input dataset." + ) + elif input_dataset.variables and input_dataset.derived_variables: + # Check so that there are no overlapping variables + if isinstance(input_dataset.variables, list): + variable_vars = input_dataset.variables + elif isinstance(input_dataset.variables, dict): + variable_vars = input_dataset.variables.keys() + else: + raise TypeError( + f"Expected an instance of list or dict, but got {type(input_dataset.variables)}." + ) + derived_variable_vars = input_dataset.derived_variables.keys() + common_vars = list(set(variable_vars) & set(derived_variable_vars)) + if len(common_vars) > 0: + raise InvalidConfigException( + "Both `variables` and `derived_variables` include the following variables name(s):" + f" '{', '.join(common_vars)}'. This is not allowed. 
Make sure that there" + " are no overlapping variable names between `variables` and `derived_variables`," + f" either by renaming or removing '{', '.join(common_vars)}' from one of them." + ) + return config + if __name__ == "__main__": import argparse @@ -311,7 +386,7 @@ class _(JSONWizard.Meta): ) args = argparser.parse_args() - config = Config.from_yaml_file(args.f) + config = Config.load_config(args.f) import rich rich.print(config) diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py index 73996cf..93cf82d 100644 --- a/mllam_data_prep/create_dataset.py +++ b/mllam_data_prep/create_dataset.py @@ -10,10 +10,12 @@ from . import __version__ from .config import Config, InvalidConfigException -from .ops.loading import load_and_subset_dataset +from .ops.derived_variables import derive_variables +from .ops.loading import load_input_dataset from .ops.mapping import map_dims_and_variables from .ops.selection import select_by_kwargs from .ops.statistics import calc_stats +from .ops.subsetting import subset_dataset # the `extra` field in the config that was added between v0.2.0 and v0.5.0 is # optional, so we can support both v0.2.0 and v0.5.0 @@ -30,11 +32,14 @@ def _check_dataset_attributes(ds, expected_attributes, dataset_name): # check for attributes having the wrong value incorrect_attributes = { - k: v for k, v in expected_attributes.items() if ds.attrs[k] != v + key: val for key, val in expected_attributes.items() if ds.attrs[key] != val } if len(incorrect_attributes) > 0: s_list = "\n".join( - [f"{k}: {v} != {ds.attrs[k]}" for k, v in incorrect_attributes.items()] + [ + f"{key}: {val} != {ds.attrs[key]}" + for key, val in incorrect_attributes.items() + ] ) raise ValueError( f"Dataset {dataset_name} has the following incorrect attributes: {s_list}" @@ -120,23 +125,50 @@ def create_dataset(config: Config): output_config = config.output output_coord_ranges = output_config.coord_ranges + chunking_config = config.output.chunking dataarrays_by_target = defaultdict(list) for dataset_name, input_config in config.inputs.items(): path = input_config.path variables = input_config.variables + derived_variables = input_config.derived_variables target_output_var = input_config.target_output_variable - expected_input_attributes = input_config.attributes or {} + expected_input_attributes = input_config.attributes expected_input_var_dims = input_config.dims output_dims = output_config.variables[target_output_var] logger.info(f"Loading dataset {dataset_name} from {path}") try: - ds = load_and_subset_dataset(fp=path, variables=variables) + ds_input = load_input_dataset(fp=path) except Exception as ex: raise Exception(f"Error loading dataset {dataset_name} from {path}") from ex + + # Initialize the output dataset and add dimensions + ds = xr.Dataset() + ds.attrs.update(ds_input.attrs) + for dim in ds_input.dims: + ds = ds.assign_coords({dim: ds_input.coords[dim]}) + + if variables: + logger.info(f"Subsetting dataset {dataset_name}") + ds = subset_dataset( + ds_subset=ds, + ds_input=ds_input, + variables=variables, + chunking=chunking_config, + ) + + if derived_variables: + logger.info(f"Deriving variables from {dataset_name}") + ds = derive_variables( + ds=ds, + ds_input=ds_input, + derived_variables=derived_variables, + chunking=chunking_config, + ) + _check_dataset_attributes( ds=ds, expected_attributes=expected_input_attributes, @@ -191,9 +223,8 @@ def create_dataset(config: Config): # default to making a single chunk for each dimension if chunksize is not specified # in the 
config - chunking_config = config.output.chunking or {} logger.info(f"Chunking dataset with {chunking_config}") - chunks = {d: chunking_config.get(d, int(ds[d].count())) for d in ds.dims} + chunks = {dim: chunking_config.get(dim, int(ds[dim].count())) for dim in ds.dims} ds = ds.chunk(chunks) splitting = config.output.splitting @@ -255,7 +286,7 @@ def create_dataset_zarr(fp_config, fp_zarr: str = None): The path to the zarr file to write the dataset to. If not provided, the zarr file will be written to the same directory as the config file with the extension changed to '.zarr'. """ - config = Config.from_yaml_file(file=fp_config) + config = Config.load_config(file=fp_config) ds = create_dataset(config=config) diff --git a/mllam_data_prep/ops/chunking.py b/mllam_data_prep/ops/chunking.py new file mode 100644 index 0000000..12731e1 --- /dev/null +++ b/mllam_data_prep/ops/chunking.py @@ -0,0 +1,44 @@ +import numpy as np +from loguru import logger + +# Max chunk size warning +CHUNK_MAX_SIZE_WARNING = 1 * 1024**3 # 1GB + + +def check_chunk_size(ds, chunks): + """ + Check the chunk size and warn if it exceed CHUNK_MAX_SIZE_WARNING. + + Parameters + ---------- + ds: xr.Dataset + Dataset to be chunked + chunks: Dict[str, int] + Dictionary with keys as dimensions to be chunked and + chunk sizes as the values + + Returns + ------- + ds: xr.Dataset + Dataset with chunking applied + """ + + # Check the chunk size + for var_name, var_data in ds.data_vars.items(): + total_size = 1 + + for dim, chunk_size in chunks.items(): + dim_size = ds.sizes.get(dim, None) + if dim_size is None: + raise KeyError(f"Dimension '{dim}' not found in the dataset.") + total_size *= chunk_size + + dtype = var_data.dtype + bytes_per_element = np.dtype(dtype).itemsize + + memory_usage = total_size * bytes_per_element + + if memory_usage > CHUNK_MAX_SIZE_WARNING: + logger.warning( + f"The chunk size for '{var_name}' exceeds '{CHUNK_MAX_SIZE_WARNING}' GB." + ) diff --git a/mllam_data_prep/ops/derived_variables.py b/mllam_data_prep/ops/derived_variables.py new file mode 100644 index 0000000..7502deb --- /dev/null +++ b/mllam_data_prep/ops/derived_variables.py @@ -0,0 +1,511 @@ +""" +Handle deriving new variables (xr.DataArrays) from an individual input dataset +that has been loaded. This makes it possible to for example add fields that can +be derived from analytical expressions and are functions of coordinate values +(e.g. top-of-atmosphere incoming radiation is a function of time and lat/lon location), +but also of other physical fields (wind-speed is a function of both meridional +and zonal wind components). 
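+
+For example, a config entry using one of the functions in this module could look
+like the following (c.f. `example.danra.yaml`):
+
+    derived_variables:
+      toa_radiation:
+        kwargs:
+          time: time
+          lat: lat
+          lon: lon
+        function: mllam_data_prep.ops.derived_variables.calculate_toa_radiation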
+""" +import datetime +import importlib +import sys + +import numpy as np +import xarray as xr +from loguru import logger + +from .chunking import check_chunk_size + +REQUIRED_FIELD_ATTRIBUTES = ["units", "long_name"] + + +def derive_variables(ds, ds_input, derived_variables, chunking): + """ + Load the dataset, and derive the specified variables + + Parameters + --------- + ds : xr.Dataset + Output dataset + ds_input : xr.Dataset + Input/source dataset + derived_variables : Dict[str, DerivedVariable] + Dictionary with the variables to derive with keys as the variable + names and values with entries for kwargs and function to use in + the calculation + chunking: Dict[str, int] + Dictionary with keys as the dimensions to chunk along and values + with the chunk size + + Returns + ------- + xr.Dataset + Dataset with derived variables included + """ + + target_dims = list(ds_input.sizes.keys()) + + for _, derived_variable in derived_variables.items(): + required_kwargs = derived_variable.kwargs + function_name = derived_variable.function + expected_field_attributes = derived_variable.attrs + + # Separate the lat,lon from the required variables as these will be derived separately + logger.warning( + "Assuming that the lat/lon coordinates are given as variables called" + " 'lat' and 'lon'." + ) + latlon_coords_to_include = {} + for key in list(required_kwargs.keys()): + if key in ["lat", "lon"]: + latlon_coords_to_include[key] = required_kwargs.pop(key) + + # Get subset of input dataset for calculating derived variables + ds_subset = ds_input[required_kwargs.keys()] + + # Any coordinates needed for the derivation, for which chunking should be performed, + # should be converted to variables since it is not possible for *indexed* coordinates + # to be chunked dask arrays + chunks = { + dim: chunking.get(dim, int(ds_subset[dim].count())) + for dim in ds_subset.dims + } + required_coordinates = [ + req_var for req_var in required_kwargs.keys() if req_var in ds_subset.coords + ] + ds_subset = ds_subset.drop_indexes(required_coordinates, errors="ignore") + for req_coord in required_coordinates: + if req_coord in chunks: + ds_subset = ds_subset.reset_coords(req_coord) + + # Chunk the dataset + ds_subset = _chunk_dataset(ds_subset, chunks) + + # Add function arguments to kwargs + kwargs = {} + if len(latlon_coords_to_include): + latlon = get_latlon_coords_for_input(ds_input) + for key, val in latlon_coords_to_include.items(): + kwargs[val] = latlon[key] + kwargs.update({val: ds_subset[key] for key, val in required_kwargs.items()}) + func = _get_derived_variable_function(function_name) + # Calculate the derived variable + derived_field = func(**kwargs) + + # Check that the derived field has the necessary attributes (REQUIRED_FIELD_ATTRIBUTES) + # set, return any dropped/reset coordinates, align it to the output dataset dimensions + # (if necessary) and add it to the dataset + if isinstance(derived_field, xr.DataArray): + derived_field = _check_for_required_attributes( + derived_field, expected_field_attributes + ) + derived_field = _return_dropped_coordinates( + derived_field, ds_subset, required_coordinates, chunks + ) + derived_field = _align_derived_variable( + derived_field, ds_input, target_dims + ) + ds[derived_field.name] = derived_field + elif isinstance(derived_field, tuple) and all( + isinstance(field, xr.DataArray) for field in derived_field + ): + for field in derived_field: + field = _check_for_required_attributes(field, expected_field_attributes) + field = _return_dropped_coordinates( + 
field, ds_subset, required_coordinates, chunks
+                )
+                field = _align_derived_variable(field, ds_input, target_dims)
+                ds[field.name] = field
+        else:
+            raise TypeError(
+                "Expected an instance of xr.DataArray or tuple(xr.DataArray),"
+                f" but got {type(derived_field)}."
+            )
+
+    return ds
+
+
+def _chunk_dataset(ds, chunks):
+    """
+    Check the chunk size and chunk the dataset.
+
+    Parameters
+    ----------
+    ds: xr.Dataset
+        Dataset to be chunked
+    chunks: Dict[str, int]
+        Dictionary with keys as dimensions to be chunked and
+        chunk sizes as the values
+
+    Returns
+    -------
+    ds: xr.Dataset
+        Dataset with chunking applied
+    """
+    # Check the chunk size
+    check_chunk_size(ds, chunks)
+
+    # Try chunking
+    try:
+        ds = ds.chunk(chunks)
+    except Exception as ex:
+        raise Exception(f"Error chunking dataset: {ex}") from ex
+
+    return ds
+
+
+def _get_derived_variable_function(function_namespace):
+    """
+    Get the function for deriving the specified variable.
+
+    Parameters
+    ----------
+    function_namespace: str
+        The full function namespace or just the function name
+        if it is a function included in this module.
+
+    Returns
+    -------
+    function: object
+        Function for deriving the specified variable
+    """
+    # Get the name of the calling module
+    calling_module = globals()["__name__"]
+
+    # Get module and function names
+    module_name, _, function_name = function_namespace.rpartition(".")
+
+    # Check if the module_name is pointing to here (the calling module or empty "").
+    # If it does, then use globals() to get the function, otherwise import the
+    # correct module and get the correct function
+    if module_name in [calling_module, ""]:
+        function = globals().get(function_name)
+        if not function:
+            raise TypeError(
+                f"Function '{function_namespace}' was not found in '{calling_module}'."
+                f" Check that you have specified the correct function name"
+                " and/or that you have defined the full function namespace if you"
+                " want to use a function defined outside of the current module"
+                f" '{calling_module}'."
+            )
+    else:
+        # Check if the module is already imported
+        if module_name in sys.modules:
+            module = sys.modules[module_name]
+        else:
+            module = importlib.import_module(module_name)
+
+        # Get the function from the module
+        function = getattr(module, function_name)
+
+    return function
+
+
+def _check_for_required_attributes(field, expected_attributes):
+    """
+    Check the attributes of the derived variable.
+
+    Parameters
+    ----------
+    field: xr.DataArray
+        The derived field
+    expected_attributes: Dict[str, str]
+        Dictionary with expected attributes for the derived variables.
+        Defined in the config file.
+
+    Returns
+    -------
+    field: xr.DataArray
+        The derived field
+    """
+    for attribute in REQUIRED_FIELD_ATTRIBUTES:
+        if attribute not in field.attrs or field.attrs[attribute] is None:
+            if attribute in expected_attributes.keys():
+                field.attrs[attribute] = expected_attributes[attribute]
+            else:
+                # The expected attributes are empty and the attributes have not been
+                # set during the calculation of the derived variable
+                raise KeyError(
+                    f'The attribute "{attribute}" has not been set for the derived'
+                    f' variable "{field.name}". This is most likely because you are'
+                    " using a function external to `mllam-data-prep` to derive the field,"
+                    f" in which the required attributes ({', '.join(REQUIRED_FIELD_ATTRIBUTES)})"
+                    " are not set. If they are not set in the function call when deriving the field,"
+                    ' they can be set in the config file by adding an "attrs" section under the'
+                    f' "{field.name}" derived variable section. For example, if the required attributes'
+                    f" ({', '.join(REQUIRED_FIELD_ATTRIBUTES)}) are not set for a derived variable named"
+                    f' "toa_radiation", they can be set by adding the following to the config file:'
+                    ' {"attrs": {"units": "W*m**-2", "long_name": "top-of-atmosphere incoming radiation"}}.'
+                )
+        elif attribute in expected_attributes.keys():
+            logger.warning(
+                f"The attribute '{attribute}' of the derived field"
+                f" {field.name} is being overwritten from"
+                f" '{field.attrs[attribute]}' to"
+                f" '{expected_attributes[attribute]}' according"
+                " to the specification in the config file."
+            )
+            field.attrs[attribute] = expected_attributes[attribute]
+        else:
+            # Attributes are set in the function and nothing has been defined in the config file
+            pass
+
+    return field
+
+
+def _return_dropped_coordinates(derived_field, ds, required_coordinates, chunks):
+    """
+    Return the coordinates that have been dropped/reset.
+
+    Parameters
+    ----------
+    derived_field: xr.DataArray
+        Derived variable
+    ds: xr.Dataset
+        Dataset with the required coordinates
+    required_coordinates: List[str]
+        List of coordinates required for the derived variable
+    chunks: Dict[str, int]
+        Dictionary with keys as dimensions to be chunked and
+        chunk sizes as the values
+
+    Returns
+    -------
+    derived_field: xr.DataArray
+        Derived variable, now also with dropped coordinates returned
+    """
+    for req_coord in required_coordinates:
+        if req_coord in chunks:
+            derived_field.coords[req_coord] = ds[req_coord]
+
+    return derived_field
+
+
+def _align_derived_variable(field, ds, target_dims):
+    """
+    Align a derived variable to the target dimensions (ignoring non-dimension coordinates).
+
+    Parameters
+    ----------
+    field: xr.DataArray
+        Derived field to align
+    ds: xr.Dataset
+        Target dataset
+    target_dims: List[str]
+        Dimensions to align to (e.g. 'time', 'y', 'x')
+
+    Returns
+    -------
+    field: xr.DataArray
+        The derived field aligned to the target dimensions
+    """
+    # Ensure that dimensions are ordered correctly
+    field = field.transpose(
+        *[dim for dim in target_dims if dim in field.dims], missing_dims="ignore"
+    )
+
+    # Add missing dimensions explicitly
+    for dim in target_dims:
+        if dim not in field.dims:
+            field = field.expand_dims({dim: ds.sizes[dim]})
+
+    # Broadcast to match only the target dimensions
+    broadcast_shape = {dim: ds[dim] for dim in target_dims if dim in ds.dims}
+    field = field.broadcast_like(xr.Dataset(coords=broadcast_shape))
+
+    return field
+
+
+def calculate_toa_radiation(lat, lon, time):
+    """
+    Function for calculating top-of-atmosphere incoming radiation
+
+    Parameters
+    ----------
+    lat : Union[xr.DataArray, float]
+        Latitude values. Should be in the range [-90, 90]
+    lon : Union[xr.DataArray, float]
+        Longitude values.
Should be in the range [-180, 180] or [0, 360] + time : Union[xr.DataArray, datetime.datetime] + Time + + Returns + ------- + toa_radiation : Union[xr.DataArray, float] + Top-of-atmosphere incoming radiation + """ + logger.info("Calculating top-of-atmosphere incoming radiation") + + # Solar constant + solar_constant = 1366 # W*m**-2 + + # Different handling if xr.DataArray or datetime object + if isinstance(time, xr.DataArray): + day = time.dt.dayofyear + hour_utc = time.dt.hour + elif isinstance(time, datetime.datetime): + day = time.timetuple().tm_yday + hour_utc = time.hour + else: + raise TypeError( + "Expected an instance of xr.DataArray or datetime object," + f" but got {type(time)}." + ) + + # Eq. 1.6.1a in Solar Engineering of Thermal Processes 4th ed. + # dec: declination - angular position of the sun at solar noon w.r.t. + # the plane of the equator + dec = np.pi / 180 * 23.45 * np.sin(2 * np.pi * (284 + day) / 365) + + utc_solar_time = hour_utc + lon / 15 + hour_angle = 15 * (utc_solar_time - 12) + + # Eq. 1.6.2 with beta=0 in Solar Engineering of Thermal Processes 4th ed. + # cos_sza: Cosine of solar zenith angle + cos_sza = np.sin(lat * np.pi / 180) * np.sin(dec) + np.cos( + lat * np.pi / 180 + ) * np.cos(dec) * np.cos(hour_angle * np.pi / 180) + + # Where TOA radiation is negative, set to 0 + toa_radiation = xr.where(solar_constant * cos_sza < 0, 0, solar_constant * cos_sza) + + if isinstance(toa_radiation, xr.DataArray): + # Add attributes + toa_radiation.name = "toa_radiation" + toa_radiation.attrs["long_name"] = "top-of-atmosphere incoming radiation" + toa_radiation.attrs["units"] = "W*m**-2" + + return toa_radiation + + +def calculate_hour_of_day(time): + """ + Function for calculating hour of day features with a cyclic encoding + + Parameters + ---------- + time : Union[xr.DataArray, datetime.datetime] + Time + + Returns + ------- + hour_of_day_cos: Union[xr.DataArray, float] + cosine of the hour of day + hour_of_day_sin: Union[xr.DataArray, float] + sine of the hour of day + """ + logger.info("Calculating hour of day") + + # Get the hour of the day + if isinstance(time, xr.DataArray): + hour_of_day = time.dt.hour + elif isinstance(time, datetime.datetime): + hour_of_day = time.hour + else: + raise TypeError( + "Expected an instance of xr.DataArray or datetime object," + f" but got {type(time)}." 
+ ) + + # Cyclic encoding of hour of day + hour_of_day_cos, hour_of_day_sin = cyclic_encoding(hour_of_day, 24) + + if isinstance(hour_of_day_cos, xr.DataArray): + # Add attributes + hour_of_day_cos.name = "hour_of_day_cos" + hour_of_day_cos.attrs[ + "long_name" + ] = "Cosine component of cyclically encoded hour of day" + hour_of_day_cos.attrs["units"] = "1" + + if isinstance(hour_of_day_sin, xr.DataArray): + # Add attributes + hour_of_day_sin.name = "hour_of_day_sin" + hour_of_day_sin.attrs[ + "long_name" + ] = "Sine component of cyclically encoded hour of day" + hour_of_day_sin.attrs["units"] = "1" + + return hour_of_day_cos, hour_of_day_sin + + +def calculate_day_of_year(time): + """ + Function for calculating day of year features with a cyclic encoding + + Parameters + ---------- + time : Union[xr.DataArray, datetime.datetime] + Time + + Returns + ------- + day_of_year_cos: Union[xr.DataArray, float] + cosine of the day of year + day_of_year_sin: Union[xr.DataArray, float] + sine of the day of year + """ + logger.info("Calculating day of year") + + # Get the day of year + if isinstance(time, xr.DataArray): + day_of_year = time.dt.dayofyear + elif isinstance(time, datetime.datetime): + day_of_year = time.timetuple().tm_yday + else: + raise TypeError( + "Expected an instance of xr.DataArray or datetime object," + f" but got {type(time)}." + ) + + # Cyclic encoding of day of year - use 366 to include leap years! + day_of_year_cos, day_of_year_sin = cyclic_encoding(day_of_year, 366) + + if isinstance(day_of_year_cos, xr.DataArray): + # Add attributes + day_of_year_cos.name = "day_of_year_cos" + day_of_year_cos.attrs[ + "long_name" + ] = "Cosine component of cyclically encoded day of year" + day_of_year_cos.attrs["units"] = "1" + + if isinstance(day_of_year_sin, xr.DataArray): + # Add attributes + day_of_year_sin.name = "day_of_year_sin" + day_of_year_sin.attrs[ + "long_name" + ] = "Sine component of cyclically encoded day of year" + day_of_year_sin.attrs["units"] = "1" + + return day_of_year_cos, day_of_year_sin + + +def cyclic_encoding(data, data_max): + """ + Cyclic encoding of data + + Parameters + ---------- + data : Union[xr.DataArray, float, int] + Data that should be cyclically encoded + data_max: Union[int, float] + Maximum possible value of input data. Should be greater than 0. 
+ + Returns + ------- + data_cos: Union[xr.DataArray, float, int] + Cosine part of cyclically encoded input data + data_sin: Union[xr.DataArray, float, int] + Sine part of cyclically encoded input data + """ + + data_sin = np.sin((data / data_max) * 2 * np.pi) + data_cos = np.cos((data / data_max) * 2 * np.pi) + + return data_cos, data_sin + + +def get_latlon_coords_for_input(ds): + """Dummy function for getting lat and lon.""" + return ds[["lat", "lon"]].chunk(-1, -1) diff --git a/mllam_data_prep/ops/loading.py b/mllam_data_prep/ops/loading.py index 955fafd..f6bfc34 100644 --- a/mllam_data_prep/ops/loading.py +++ b/mllam_data_prep/ops/loading.py @@ -1,20 +1,20 @@ import xarray as xr -def load_and_subset_dataset(fp, variables): +def load_input_dataset(fp): """ - Load the dataset, subset the variables along the specified coordinates and - check coordinate units + Load the dataset Parameters ---------- fp : str Filepath to the source dataset, for example the path to a zarr dataset or a netCDF file (anything that is supported by `xarray.open_dataset` will work) - variables : dict - Dictionary with the variables to subset - with keys as the variable names and values with entries for each - coordinate and coordinate values to extract + + Returns + ------- + ds: xr.Dataset + Source dataset """ try: @@ -22,36 +22,4 @@ def load_and_subset_dataset(fp, variables): except ValueError: ds = xr.open_dataset(fp) - ds_subset = xr.Dataset() - ds_subset.attrs.update(ds.attrs) - if isinstance(variables, dict): - for var, coords_to_sample in variables.items(): - da = ds[var] - for coord, sampling in coords_to_sample.items(): - coord_values = sampling.values - try: - da = da.sel(**{coord: coord_values}) - except KeyError as ex: - raise KeyError( - f"Could not find the all coordinate values `{coord_values}` in " - f"coordinate `{coord}` in the dataset" - ) from ex - expected_units = sampling.units - coord_units = da[coord].attrs.get("units", None) - if coord_units is not None and coord_units != expected_units: - raise ValueError( - f"Expected units {expected_units} for coordinate {coord}" - f" in variable {var} but got {coord_units}" - ) - ds_subset[var] = da - elif isinstance(variables, list): - try: - ds_subset = ds[variables] - except KeyError as ex: - raise KeyError( - f"Could not find the all variables `{variables}` in the dataset. 
" - f"The available variables are {list(ds.data_vars)}" - ) from ex - else: - raise ValueError("The `variables` argument should be a list or a dictionary") - return ds_subset + return ds diff --git a/mllam_data_prep/ops/subsetting.py b/mllam_data_prep/ops/subsetting.py new file mode 100644 index 0000000..d2ba3a8 --- /dev/null +++ b/mllam_data_prep/ops/subsetting.py @@ -0,0 +1,57 @@ +def subset_dataset(ds_subset, ds_input, variables, chunking): + """ + Select specific variables from the provided the dataset, subset the + variables along the specified coordinates and check coordinate units + + Parameters + ---------- + ds_subset : xr.Dataset + Subset of ds_input + ds_input : xr.Dataset + Input/source dataset + variables : dict + Dictionary with the variables to subset + with keys as the variable names and values with entries for each + coordinate and coordinate values to extract + chunking: dict + Dictionary with keys as the dimensions to chunk along and values + with the chunk size + """ + + if isinstance(variables, dict): + for var, coords_to_sample in variables.items(): + da = ds_input[var] + for coord, sampling in coords_to_sample.items(): + coord_values = sampling.values + try: + da = da.sel(**{coord: coord_values}) + except KeyError as ex: + raise KeyError( + f"Could not find the all coordinate values `{coord_values}` in " + f"coordinate `{coord}` in the dataset" + ) from ex + expected_units = sampling.units + coord_units = da[coord].attrs.get("units", None) + if coord_units is not None and coord_units != expected_units: + raise ValueError( + f"Expected units {expected_units} for coordinate {coord}" + f" in variable {var} but got {coord_units}" + ) + ds_subset[var] = da + elif isinstance(variables, list): + try: + ds_subset = ds_input[variables] + except KeyError as ex: + raise KeyError( + f"Could not find the all variables `{variables}` in the dataset. 
" + f"The available variables are {list(ds_input.data_vars)}" + ) from ex + else: + raise ValueError("The `variables` argument should be a list or a dictionary") + + chunks = { + dim: chunking.get(dim, int(ds_subset[dim].count())) for dim in ds_subset.dims + } + ds_subset = ds_subset.chunk(chunks) + + return ds_subset diff --git a/tests/test_derived_variables.py b/tests/test_derived_variables.py new file mode 100644 index 0000000..786e064 --- /dev/null +++ b/tests/test_derived_variables.py @@ -0,0 +1,117 @@ +import datetime +import random +from unittest.mock import patch + +import isodate +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +import mllam_data_prep as mdp + +NCOORD = 10 +NTIME = 10 +LAT_MIN = -90 +LAT_MAX = 90 +LON_MIN = 0 +LON_MAX = 360 +LATITUDE = [ + 55.711, + xr.DataArray( + np.random.uniform(LAT_MIN, LAT_MAX, size=(NCOORD, NCOORD)), + dims=["x", "y"], + coords={"x": np.arange(NCOORD), "y": np.arange(NCOORD)}, + name="lat", + ), +] +LONGITUDE = [ + 12.564, + xr.DataArray( + np.random.uniform(LON_MIN, LON_MAX, size=(NCOORD, NCOORD)), + dims=["x", "y"], + coords={"x": np.arange(NCOORD), "y": np.arange(NCOORD)}, + name="lon", + ), +] +TIME = [ + np.datetime64("2004-06-11T00:00:00"), # invalid type + isodate.parse_datetime("1999-03-21T00:00"), + xr.DataArray( + pd.date_range( + start=isodate.parse_datetime("1999-03-21T00:00"), + periods=NTIME, + freq=isodate.parse_duration("PT1H"), + ), + dims=["time"], + name="time", + ), +] + + +def mock_cyclic_encoding(data, data_max): + """Mock the `cyclic_encoding` function from mllam_data_prep.ops.derived_variables.""" + if isinstance(data, xr.DataArray): + data_cos = xr.DataArray( + random.uniform(-1, 1), + coords=data.coords, + dims=data.dims, + ) + data_sin = xr.DataArray( + random.uniform(-1, 1), + coords=data.coords, + dims=data.dims, + ) + return data_cos, data_sin + elif isinstance(data, (float, int)): + return random.uniform(-1, 1), random.uniform(-1, 1) + + +@pytest.mark.parametrize("lat", LATITUDE) +@pytest.mark.parametrize("lon", LONGITUDE) +@pytest.mark.parametrize("time", TIME) +def test_toa_radiation(lat, lon, time): + """ + Test the `calculate_toa_radiation` function from mllam_data_prep.derived_variables + """ + with patch( + "mllam_data_prep.ops.derived_variables.cyclic_encoding", + side_effect=mock_cyclic_encoding, + ): + if isinstance(time, (xr.DataArray, datetime.datetime)): + mdp.ops.derived_variables.calculate_toa_radiation(lat, lon, time) + else: + with pytest.raises(TypeError): + mdp.ops.derived_variables.calculate_toa_radiation(lat, lon, time) + + +@pytest.mark.parametrize("time", TIME) +def test_hour_of_day(time): + """ + Test the `calculate_hour_of_day` function from mllam_data_prep.derived_variables + """ + with patch( + "mllam_data_prep.ops.derived_variables.cyclic_encoding", + side_effect=mock_cyclic_encoding, + ): + if isinstance(time, (xr.DataArray, datetime.datetime)): + mdp.ops.derived_variables.calculate_hour_of_day(time) + else: + with pytest.raises(TypeError): + mdp.ops.derived_variables.calculate_hour_of_day(time) + + +@pytest.mark.parametrize("time", TIME) +def test_day_of_year(time): + """ + Test the `calculate_day_of_year` function from mllam_data_prep.derived_variables + """ + with patch( + "mllam_data_prep.ops.derived_variables.cyclic_encoding", + side_effect=mock_cyclic_encoding, + ): + if isinstance(time, (xr.DataArray, datetime.datetime)): + mdp.ops.derived_variables.calculate_day_of_year(time) + else: + with pytest.raises(TypeError): + 
mdp.ops.derived_variables.calculate_day_of_year(time)
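
Taken together, the changes above mean that a dataset containing both selected and derived variables can be produced from a config such as `example.danra.yaml` using the existing Python API. A minimal usage sketch (assuming the config includes a `derived_variables` section as in the examples above):

```python
import mllam_data_prep as mdp

# Load and validate the config; `load_config` wraps `from_yaml_file` and checks
# that each input defines `variables` and/or `derived_variables` without overlaps
config = mdp.Config.load_config("example.danra.yaml")

# Create the dataset; derived variables such as `toa_radiation` and
# `hour_of_day_cos`/`hour_of_day_sin` are computed from the input datasets
ds = mdp.create_dataset(config=config)
print(ds)
```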