diff --git a/mllam_data_prep/config.py b/mllam_data_prep/config.py
index be72de9..c6192d1 100644
--- a/mllam_data_prep/config.py
+++ b/mllam_data_prep/config.py
@@ -78,6 +78,7 @@ class DerivedVariable:
 
     kwargs: Dict[str, str]
     function: str
+    attributes: Dict[str, Any] = None
 
 
 @dataclass
@@ -148,7 +149,8 @@ class InputDataset:
     1) the path to the dataset,
     2) the expected dimensions of the dataset,
     3) the variables to select from the dataset (and optionally subsection
-       along the coordinates for each variable) and finally
+       along the coordinates for each variable) and/or the variables to derive
+       from the dataset, and finally
     4) the method by which the dimensions and variables of the dataset are
        mapped to one of the output variables (this includes stacking of all
        the selected variables into a new single variable along a new coordinate,
@@ -179,6 +181,12 @@ class InputDataset:
         (e.g. two datasets that coincide in space and time will only differ in the feature dimension,
         so the two will be combined by concatenating along the feature dimension).
         If a single shared coordinate cannot be found then an exception will be raised.
+    derived_variables: Dict[str, DerivedVariable]
+        Dictionary of variables to derive from the dataset, where the keys are the variable names and
+        the values are dictionaries defining the necessary function and kwargs. E.g.
+        `{"toa_radiation": {"kwargs": {"time": "time", "lat": "lat", "lon": "lon"}, "function": "calculate_toa_radiation"}}`
+        would derive the "toa_radiation" variable using the `calculate_toa_radiation` function, which
+        takes `time`, `lat` and `lon` as arguments.
     """
 
     path: str
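To make the new `attributes` field concrete before moving on to the pipeline changes: functions shipped with `mllam-data-prep` (such as `calculate_toa_radiation`, see below) set `units` and `long_name` themselves, so `attributes` can be left out, whereas externally supplied functions need those attributes spelled out in the config. A minimal sketch of the two cases, assuming `kwargs`, `function` and `attributes` are the only fields of the dataclass; the external function namespace and its attribute values are hypothetical, and such entries would normally be written in the config file rather than constructed by hand:

```python
from mllam_data_prep.config import DerivedVariable

# Built-in function: `calculate_toa_radiation` sets "units" and "long_name"
# itself, so `attributes` can be omitted (it defaults to None).
toa_radiation = DerivedVariable(
    kwargs={"time": "time", "lat": "lat", "lon": "lon"},
    function="calculate_toa_radiation",
)

# External function (hypothetical namespace): "units" and "long_name" must be
# supplied via `attributes`, otherwise `_check_attributes` raises a ValueError.
my_custom_variable = DerivedVariable(
    kwargs={"time": "time"},
    function="my_package.my_module.calculate_my_variable",
    attributes={"units": "1", "long_name": "my custom derived variable"},
)
```

Defaulting `attributes` to `None` keeps existing configs valid; `derive_variables` normalises the missing value with `derived_variable.attributes or {}`.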
diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py
index 1a2f389..4ce5e14 100644
--- a/mllam_data_prep/create_dataset.py
+++ b/mllam_data_prep/create_dataset.py
@@ -148,6 +148,7 @@ def create_dataset(config: Config):
         except Exception as ex:
             raise Exception(
                 f"Error loading dataset {dataset_name} from {path}"
+                f" or deriving variables '{', '.join(list(derived_variables.keys()))}'."
             ) from ex
         _check_dataset_attributes(
             ds=ds,
diff --git a/mllam_data_prep/derived_variables.py b/mllam_data_prep/derived_variables.py
index 760e0b3..cda1bdf 100644
--- a/mllam_data_prep/derived_variables.py
+++ b/mllam_data_prep/derived_variables.py
@@ -40,6 +40,7 @@ def derive_variables(fp, derived_variables, chunking):
     for _, derived_variable in derived_variables.items():
         required_variables = derived_variable.kwargs
         function_name = derived_variable.function
+        derived_variable_attributes = derived_variable.attributes or {}
        ds_input = ds[required_variables.keys()]
 
         # Any coordinates needed for the derivation, for which chunking should be performed,
@@ -61,35 +62,50 @@ def derive_variables(fp, derived_variables, chunking):
 
         # Calculate the derived variable
         kwargs = {v: ds_input[k] for k, v in required_variables.items()}
-        func = get_derived_variable_function(function_name)
+        func = _get_derived_variable_function(function_name)
         derived_field = func(**kwargs)
 
-        # Some of the derived variables include two components, since
-        # they are cyclically encoded (cos and sin parts)
+        # Check the derived field(s)
+        derived_field = _check_field(
+            derived_field,
+            derived_variable_attributes,
+            ds_input,
+            required_coordinates,
+            chunks,
+        )
+
+        # Add the derived field(s) to the subset
         if isinstance(derived_field, xr.DataArray):
-            derived_field = _return_dropped_coordinates(
-                derived_field, ds_input, required_coordinates, chunks
-            )
             ds_subset[derived_field.name] = derived_field
-        elif isinstance(derived_field, tuple):
+        elif isinstance(derived_field, tuple) and all(
+            isinstance(field, xr.DataArray) for field in derived_field
+        ):
             for field in derived_field:
-                field = _return_dropped_coordinates(
-                    field, ds_input, required_coordinates, chunks
-                )
                 ds_subset[field.name] = field
+        else:
+            raise TypeError(
+                "Expected an instance of xr.DataArray or tuple(xr.DataArray),"
+                f" but got {type(derived_field)}."
+            )
 
     return ds_subset
 
 
-def get_derived_variable_function(function_namespace):
+def _get_derived_variable_function(function_namespace):
     """
-    Function for returning the function to be used to derive
+    Function for getting the function for deriving
     the specified variable.
 
-    1. Check if the function to use is in globals()
-    2. If it is in globals then call it
-    3. If it isn't in globals() then import the necessary module
-       before calling it
+    Parameters
+    ----------
+    function_namespace: str
+        The full function namespace or just the function name
+        if it is a function included in this module.
+
+    Returns
+    -------
+    function: object
+        Function for deriving the specified variable
     """
     # Get the name of the calling module
     calling_module = globals()["__name__"]
@@ -127,13 +143,111 @@ def get_derived_variable_function(function_namespace):
 
     return function
 
 
-def _return_dropped_coordinates(derived_field, ds_input, required_coordinates, chunks):
-    """Return coordinates that have been reset."""
+def _check_field(
+    derived_field, derived_field_attributes, ds_input, required_coordinates, chunks
+):
+    """
+    Check the derived field.
+
+    Parameters
+    ----------
+    derived_field: xr.DataArray or tuple
+        The derived variable
+    derived_field_attributes: dict
+        Dictionary with attributes for the derived variables.
+        Defined in the config file.
+    ds_input: xr.Dataset
+        xarray dataset with variables needed to derive the specified variable
+    required_coordinates: list
+        List of coordinates required for deriving the specified variable
+    chunks: dict
+        Dictionary with keys as the dimensions to chunk along and values
+        with the chunk size, only including the dimensions that are also
+        included in the output.
+
+    Returns
+    -------
+    derived_field: xr.DataArray or tuple
+        The derived field
+    """
+    if isinstance(derived_field, xr.DataArray):
+        derived_field = _check_attributes(derived_field, derived_field_attributes)
+        derived_field = _return_dropped_coordinates(
+            derived_field, ds_input, required_coordinates, chunks
+        )
+    elif isinstance(derived_field, tuple) and all(
+        isinstance(field, xr.DataArray) for field in derived_field
+    ):
+        for field in derived_field:
+            field = _check_attributes(field, derived_field_attributes)
+            field = _return_dropped_coordinates(
+                field, ds_input, required_coordinates, chunks
+            )
+    else:
+        raise TypeError(
+            "Expected an instance of xr.DataArray or tuple(xr.DataArray),"
+            f" but got {type(derived_field)}."
+        )
+
+    return derived_field
+
+
+def _check_attributes(field, field_attributes):
+    """
+    Check the attributes of the derived variable.
+
+    Parameters
+    ----------
+    field: xr.DataArray
+        The derived field
+    field_attributes: dict
+        Dictionary with attributes for the derived variable.
+        Defined in the config file.
+
+    Returns
+    -------
+    field: xr.DataArray
+        The derived field
+    """
+    for attribute in ["units", "long_name"]:
+        if attribute not in field.attrs or field.attrs[attribute] is None:
+            if attribute in field_attributes.keys():
+                field.attrs[attribute] = field_attributes[attribute]
+            else:
+                # The expected attributes are empty and the attributes have not been
+                # set during the calculation of the derived variable
+                raise ValueError(
+                    f"The attribute '{attribute}' has not been set for the derived"
+                    f" variable '{field.name}' (most likely because you are using a"
+                    " function external to `mllam-data-prep` to derive the field)."
+                    " This attribute has not been defined in the 'attributes' section"
+                    " of the config file either. Make sure that you add it to the"
+                    f" 'attributes' section of the derived variable '{field.name}'."
+                )
+        else:
+            if attribute in field_attributes.keys():
+                logger.warning(
+                    f"The attribute '{attribute}' of the derived field"
+                    f" {field.name} is being overwritten from"
+                    f" '{field.attrs[attribute]}' to"
+                    f" '{field_attributes[attribute]}' according"
+                    " to specification in the config file."
+                )
+                field.attrs[attribute] = field_attributes[attribute]
+            else:
+                # Attributes are set and nothing has been defined in the config file
+                pass
+
+    return field
+
+
+def _return_dropped_coordinates(field, ds_input, required_coordinates, chunks):
+    """Return the coordinates that have been reset."""
     for req_coord in required_coordinates:
         if req_coord in chunks:
-            derived_field.coords[req_coord] = ds_input[req_coord]
+            field.coords[req_coord] = ds_input[req_coord]
 
-    return derived_field
+    return field
 
 
 def calculate_toa_radiation(lat, lon, time):
@@ -179,6 +293,8 @@ def calculate_toa_radiation(lat, lon, time):
 
     if isinstance(toa_radiation, xr.DataArray):
         # Add attributes
         toa_radiation.name = "toa_radiation"
+        toa_radiation.attrs["long_name"] = "top-of-the-atmosphere radiation"
+        toa_radiation.attrs["units"] = "W*m**-2"
 
     return toa_radiation
 
@@ -210,10 +326,18 @@ def calculate_hour_of_day(time):
 
     if isinstance(hour_of_day_cos, xr.DataArray):
         # Add attributes
         hour_of_day_cos.name = "hour_of_day_cos"
+        hour_of_day_cos.attrs[
+            "long_name"
+        ] = "Cosine component of cyclically encoded hour of day"
+        hour_of_day_cos.attrs["units"] = "1"
 
     if isinstance(hour_of_day_sin, xr.DataArray):
         # Add attributes
         hour_of_day_sin.name = "hour_of_day_sin"
+        hour_of_day_sin.attrs[
+            "long_name"
+        ] = "Sine component of cyclically encoded hour of day"
+        hour_of_day_sin.attrs["units"] = "1"
 
     return hour_of_day_cos, hour_of_day_sin
 
@@ -245,10 +369,18 @@ def calculate_day_of_year(time):
 
     if isinstance(day_of_year_cos, xr.DataArray):
         # Add attributes
         day_of_year_cos.name = "day_of_year_cos"
+        day_of_year_cos.attrs[
+            "long_name"
+        ] = "Cosine component of cyclically encoded day of year"
+        day_of_year_cos.attrs["units"] = "1"
 
     if isinstance(day_of_year_sin, xr.DataArray):
         # Add attributes
         day_of_year_sin.name = "day_of_year_sin"
+        day_of_year_sin.attrs[
+            "long_name"
+        ] = "Sine component of cyclically encoded day of year"
+        day_of_year_sin.attrs["units"] = "1"
 
     return day_of_year_cos, day_of_year_sin
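Taken together, `_check_attributes` gives config-supplied attributes precedence over whatever the deriving function sets, and fills them in where the function sets nothing. A standalone sketch of that precedence rule, using plain xarray and made-up values; it mirrors the logic above rather than calling into the module:

```python
import xarray as xr

# A derived field whose deriving function set "units" but not "long_name".
field = xr.DataArray([1.0, 2.0], dims="time", name="toa_radiation")
field.attrs["units"] = "W*m**-2"

# Attributes as they would come from the `attributes` section of the config.
config_attrs = {"units": "W m**-2", "long_name": "top-of-the-atmosphere radiation"}

for attribute in ["units", "long_name"]:
    if field.attrs.get(attribute) is None:
        # Missing on the field: it must be defined in the config,
        # otherwise `_check_attributes` raises a ValueError.
        field.attrs[attribute] = config_attrs[attribute]
    elif attribute in config_attrs:
        # Set by the function but also defined in the config: the config
        # value overwrites it (with a warning in `_check_attributes`).
        field.attrs[attribute] = config_attrs[attribute]

print(field.attrs)
# {'units': 'W m**-2', 'long_name': 'top-of-the-atmosphere radiation'}
```

This is also why the three built-in `calculate_*` functions now set both `units` and `long_name` explicitly: config entries for them can stay attribute-free.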