Skip to content

Commit

Permalink
Add dummy function for getting lat,lon (preparation for mllam#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
ealerskans committed Dec 9, 2024
1 parent 26455bc commit fbb6065
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions mllam_data_prep/derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,25 @@ def derive_variables(fp, derived_variables, chunking):
ds_subset = xr.Dataset()
ds_subset.attrs.update(ds.attrs)
for _, derived_variable in derived_variables.items():
required_variables = derived_variable.kwargs
required_kwargs = derived_variable.kwargs
function_name = derived_variable.function
derived_variable_attributes = derived_variable.attributes or {}
ds_input = ds[required_variables.keys()]

# Separate the lat,lon from the required variables as these will be derived separately
latlon_coords_to_include = {}
for k, v in list(required_kwargs.items()):
if k in ["lat", "lon"]:
latlon_coords_to_include[k] = required_kwargs.pop(k)

# Subset the dataset
ds_input = ds[required_kwargs.keys()]

# Any coordinates needed for the derivation, for which chunking should be performed,
# should be converted to variables since it is not possible for coordinates to be
# chunked dask arrays
# should be converted to variables since it is not possible for *indexed* coordinates
# to be chunked dask arrays
chunks = {d: chunking.get(d, int(ds_input[d].count())) for d in ds_input.dims}
required_coordinates = [
req_var
for req_var in required_variables.keys()
if req_var in ds_input.coords
req_var for req_var in required_kwargs.keys() if req_var in ds_input.coords
]
ds_input = ds_input.drop_indexes(required_coordinates, errors="ignore")
for req_coord in required_coordinates:
Expand All @@ -60,9 +66,15 @@ def derive_variables(fp, derived_variables, chunking):
# Chunk the data variables
ds_input = ds_input.chunk(chunks)

# Calculate the derived variable
kwargs = {v: ds_input[k] for k, v in required_variables.items()}
# Add function arguments to kwargs
kwargs = {}
if len(latlon_coords_to_include):
latlon = get_latlon_coords_for_input(ds)
for k, v in latlon_coords_to_include.items():
kwargs[v] = latlon[k]
kwargs.update({v: ds_input[k] for k, v in required_kwargs.items()})
func = _get_derived_variable_function(function_name)
# Calculate the derived variable
derived_field = func(**kwargs)

# Check the derived field(s)
Expand Down Expand Up @@ -408,3 +420,8 @@ def cyclic_encoding(data_array, da_max):
data_array_cos = np.cos((data_array / da_max) * 2 * np.pi)

return data_array_cos, data_array_sin


def get_latlon_coords_for_input(ds_input):
"""Dummy function for getting lat and lon."""
return ds_input[["lat", "lon"]].chunk(-1, -1)

0 comments on commit fbb6065

Please sign in to comment.