Skip to content

Commit

Permalink
Add optional 'attributes' to the config of 'derived_variables' and ch…
Browse files Browse the repository at this point in the history
…eck the attributes of the derived variable data-array
  • Loading branch information
ealerskans committed Dec 6, 2024
1 parent 9d2db07 commit 26455bc
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 21 deletions.
10 changes: 9 additions & 1 deletion mllam_data_prep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class DerivedVariable:

kwargs: Dict[str, str]
function: str
attributes: Dict[str, Any] = None


@dataclass
Expand Down Expand Up @@ -148,7 +149,8 @@ class InputDataset:
1) the path to the dataset,
2) the expected dimensions of the dataset,
3) the variables to select from the dataset (and optionally subsection
along the coordinates for each variable) and finally
along the coordinates for each variable) and/or the variables to derive
from the dataset, and finally
4) the method by which the dimensions and variables of the dataset are
mapped to one of the output variables (this includes stacking of all
the selected variables into a new single variable along a new coordinate,
Expand Down Expand Up @@ -179,6 +181,12 @@ class InputDataset:
(e.g. two datasets that coincide in space and time will only differ in the feature dimension,
so the two will be combined by concatenating along the feature dimension).
If a single shared coordinate cannot be found then an exception will be raised.
derived_variables: Dict[str, DerivedVariable]
Dictionary of variables to derive from the dataset, where the keys are the variable names and
the values are dictionaries defining the necessary function and kwargs. E.g.
`{"toa_radiation": {"kwargs": {"time": "time", "lat": "lat", "lon": "lon"}, "function": "calculate_toa_radiation"}}`
would derive the "toa_radiation" variable using the `calculate_toa_radiation` function, which
takes `time`, `lat` and `lon` as arguments.
"""

path: str
Expand Down
1 change: 1 addition & 0 deletions mllam_data_prep/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def create_dataset(config: Config):
except Exception as ex:
raise Exception(
f"Error loading dataset {dataset_name} from {path}"
f" or deriving variables '{', '.join(list(derived_variables.keys()))}'."
) from ex
_check_dataset_attributes(
ds=ds,
Expand Down
172 changes: 152 additions & 20 deletions mllam_data_prep/derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def derive_variables(fp, derived_variables, chunking):
for _, derived_variable in derived_variables.items():
required_variables = derived_variable.kwargs
function_name = derived_variable.function
derived_variable_attributes = derived_variable.attributes or {}
ds_input = ds[required_variables.keys()]

# Any coordinates needed for the derivation, for which chunking should be performed,
Expand All @@ -61,35 +62,50 @@ def derive_variables(fp, derived_variables, chunking):

# Calculate the derived variable
kwargs = {v: ds_input[k] for k, v in required_variables.items()}
func = get_derived_variable_function(function_name)
func = _get_derived_variable_function(function_name)
derived_field = func(**kwargs)

# Some of the derived variables include two components, since
# they are cyclically encoded (cos and sin parts)
# Check the derived field(s)
derived_field = _check_field(
derived_field,
derived_variable_attributes,
ds_input,
required_coordinates,
chunks,
)

# Add the derived field(s) to the subset
if isinstance(derived_field, xr.DataArray):
derived_field = _return_dropped_coordinates(
derived_field, ds_input, required_coordinates, chunks
)
ds_subset[derived_field.name] = derived_field
elif isinstance(derived_field, tuple):
elif isinstance(derived_field, tuple) and all(
isinstance(field, xr.DataArray) for field in derived_field
):
for field in derived_field:
field = _return_dropped_coordinates(
field, ds_input, required_coordinates, chunks
)
ds_subset[field.name] = field
else:
raise TypeError(
"Expected an instance of xr.DataArray or tuple(xr.DataArray),"
f" but got {type(derived_field)}."
)

return ds_subset


def get_derived_variable_function(function_namespace):
def _get_derived_variable_function(function_namespace):
"""
Function for returning the function to be used to derive
Function for getting the function for deriving
the specified variable.
1. Check if the function to use is in globals()
2. If it is in globals then call it
3. If it isn't in globals() then import the necessary module
before calling it
Parameters
----------
function_namespace: str
The full function namespace or just the function name
if it is a function included in this module.
Returns
-------
function: object
Function for deriving the specified variable
"""
# Get the name of the calling module
calling_module = globals()["__name__"]
Expand Down Expand Up @@ -127,13 +143,111 @@ def get_derived_variable_function(function_namespace):
return function


def _return_dropped_coordinates(derived_field, ds_input, required_coordinates, chunks):
"""Return coordinates that have been reset."""
def _check_field(
derived_field, derived_field_attributes, ds_input, required_coordinates, chunks
):
"""
Check the derived field.
Parameters
----------
derived_field: xr.DataArray or tuple
The derived variable
derived_field_attributes: dict
Dictionary with attributes for the derived variables.
Defined in the config file.
ds_input: xr.Dataset
xarray dataset with variables needed to derive the specified variable
required_coordinates: list
List of coordinates required for deriving the specified variable
chunks: dict
Dictionary with keys as the dimensions to chunk along and values
with the chunk size, only inbcluding the dimensions that are included
in the output as well.
Returns
-------
derived_field: xr.DataArray or tuple
The derived field
"""
if isinstance(derived_field, xr.DataArray):
derived_field = _check_attributes(derived_field, derived_field_attributes)
derived_field = _return_dropped_coordinates(
derived_field, ds_input, required_coordinates, chunks
)
elif isinstance(derived_field, tuple) and all(
isinstance(field, xr.DataArray) for field in derived_field
):
for field in derived_field:
field = _check_attributes(field, derived_field_attributes)
field = _return_dropped_coordinates(
field, ds_input, required_coordinates, chunks
)
else:
raise TypeError(
"Expected an instance of xr.DataArray or tuple(xr.DataArray),"
f" but got {type(derived_field)}."
)

return derived_field


def _check_attributes(field, field_attributes):
"""
Check the attributes of the derived variable.
Parameters
----------
field: xr.DataArray or tuple
The derived field
field_attributes: dict
Dictionary with attributes for the derived variables.
Defined in the config file.
Returns
-------
field: xr.DataArray or tuple
The derived field
"""
for attribute in ["units", "long_name"]:
if attribute not in field.attrs or field.attrs[attribute] is None:
if attribute in field_attributes.keys():
field.attrs[attribute] = field_attributes[attribute]
else:
# The expected attributes are empty and the attributes have not been
# set during the calculation of the derived variable
raise ValueError(
f"The attribute '{attribute}' has not been set for the derived"
f" variable '{field.name}' (most likely because you are using a"
" function external to `mlllam-data-prep` to derive the field)."
" This attribute has not been defined in the 'attributes' section"
" of the config file either. Make sure that you add it to the"
f" 'attributes' section of the derived variable '{field.name}'."
)
else:
if attribute in field_attributes.keys():
logger.warning(
f"The attribute '{attribute}' of the derived field"
f" {field.name} is being overwritten from"
f" '{field.attrs[attribute]}' to"
f" '{field_attributes[attribute]}' according"
" to specification in the config file."
)
field.attrs[attribute] = field_attributes[attribute]
else:
# Attributes are set and nothing has been defined in the config file
pass

return field


def _return_dropped_coordinates(field, ds_input, required_coordinates, chunks):
"""Return the coordinates that have been reset."""
for req_coord in required_coordinates:
if req_coord in chunks:
derived_field.coords[req_coord] = ds_input[req_coord]
field.coords[req_coord] = ds_input[req_coord]

return derived_field
return field


def calculate_toa_radiation(lat, lon, time):
Expand Down Expand Up @@ -179,6 +293,8 @@ def calculate_toa_radiation(lat, lon, time):
if isinstance(toa_radiation, xr.DataArray):
# Add attributes
toa_radiation.name = "toa_radiation"
toa_radiation.attrs["long_name"] = "top-of-the-atmosphere radiation"
toa_radiation.attrs["units"] = "W*m**-2"

return toa_radiation

Expand Down Expand Up @@ -210,10 +326,18 @@ def calculate_hour_of_day(time):
if isinstance(hour_of_day_cos, xr.DataArray):
# Add attributes
hour_of_day_cos.name = "hour_of_day_cos"
hour_of_day_cos.attrs[
"long_name"
] = "Cosine component of cyclically encoded hour of day"
hour_of_day_cos.attrs["units"] = "1"

if isinstance(hour_of_day_sin, xr.DataArray):
# Add attributes
hour_of_day_sin.name = "hour_of_day_sin"
hour_of_day_sin.attrs[
"long_name"
] = "Sine component of cyclically encoded hour of day"
hour_of_day_sin.attrs["units"] = "1"

return hour_of_day_cos, hour_of_day_sin

Expand Down Expand Up @@ -245,10 +369,18 @@ def calculate_day_of_year(time):
if isinstance(day_of_year_cos, xr.DataArray):
# Add attributes
day_of_year_cos.name = "day_of_year_cos"
day_of_year_cos.attrs[
"long_name"
] = "Cosine component of cyclically encoded day of year"
day_of_year_cos.attrs["units"] = "1"

if isinstance(day_of_year_sin, xr.DataArray):
# Add attributes
day_of_year_sin.name = "day_of_year_sin"
day_of_year_sin.attrs[
"long_name"
] = "Sine component of cyclically encoded day of year"
day_of_year_sin.attrs["units"] = "1"

return day_of_year_cos, day_of_year_sin

Expand Down

0 comments on commit 26455bc

Please sign in to comment.