Add support for writing more composite statistics (e.g. grid-point based mean of time-step differences) #42

Draft
wants to merge 9 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
.vscode/

# C extensions
*.so
6 changes: 3 additions & 3 deletions README.md
@@ -26,7 +26,7 @@ python -m pip install mllam-data-prep[dask-distributed]

## Developing `mllam-data-prep`

To work on developing `mllam-data-prep` it easiest to install and manage the dependencies with [pdm](https://pdm.fming.dev/). To get started clone your fork of [the main repo](https://github.com/mllam/mllam-data-prep) locally:
To work on developing `mllam-data-prep` it is easiest to install and manage the dependencies with [pdm](https://pdm.fming.dev/). To get started clone your fork of [the main repo](https://github.com/mllam/mllam-data-prep) locally:

```bash
git clone https://github.com/<your-github-username>/mllam-data-prep
@@ -41,7 +41,7 @@ pdm use --venv in-project
pdm install
```

All the linting is handelled by `pre-commit` which can be setup to automatically be run on each `git commit` by installing the git commit hook:
All the linting is handled by `pre-commit` which can be setup to automatically be run on each `git commit` by installing the git commit hook:

```bash
pdm run pre-commit install
@@ -256,7 +256,7 @@ The `output` section defines three things:

1. `variables`: what input variables the model architecture you are targeting expects, and what the dimensions are for each of these variables.
2. `coord_ranges`: the range of values for each of the dimensions that the model architecture expects as input. These are optional, but allow you to ensure that the training dataset is created with the correct range of values for each dimension.
3. `chunking`: the chunk sizes to use when writing the training dataset to zarr. This is optional, but can be used to optimise the performance of the zarr dataset. By default the chunk sizes are set to the size of the dimension, but this can be overridden by setting the chunk size in the configuration file. A common choice is to set the dimension along which you are batching to align with the of each training item (e.g. if you are training a model with time-step roll-out of 10 timesteps, you might choose a chunksize of 10 along the time dimension).
3. `chunking`: the chunk sizes to use when writing the training dataset to zarr. This is optional, but can be used to optimise the performance of the zarr dataset. By default the chunk sizes are set to the size of the dimension, but this can be overridden by setting the chunk size in the configuration file. A common choice is to set the dimension along which you are batching to align with that of each training item (e.g. if you are training a model with time-step roll-out of 10 timesteps, you might choose a chunksize of 10 along the time dimension).
4. Splitting and calculation of statistics of the output variables, using the `splitting` section. The `output.splitting.splits` attribute defines the individual splits to create (for example `train`, `val` and `test`) and `output.splitting.dim` defines the dimension to split along. The `compute_statistics` attribute can optionally be set for a given split to calculate the statistical properties requested (for example `mean`, `std`); any method available as `xarray.Dataset.{op}` can be used. In addition, methods prefixed by `diff_` (so the operation would be listed as `diff_{op}`) compute a statistic based on the difference of consecutive time-steps, e.g. `diff_mean` to compute the `mean` of the difference between consecutive timesteps (these are used for normalising increments). The `dims` attribute defines the dimensions to calculate the statistics over (for example `grid_index` and `time`).
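
For illustration, a `diff_{op}` statistic amounts to differencing along `time` before applying the reduction. A minimal sketch of the idea with xarray (the variable name and array sizes below are made up for illustration; this is not the package's implementation):

```python
import numpy as np
import xarray as xr

# toy dataset with the dimensions used in the examples above
ds = xr.Dataset(
    {"t2m": (("time", "grid_index"), np.random.rand(8, 4))},
)

dims = ["grid_index", "time"]
mean = ds["t2m"].mean(dim=dims)                        # `mean`
diff_mean = ds["t2m"].diff(dim="time").mean(dim=dims)  # `diff_mean`
diff_std = ds["t2m"].diff(dim="time").std(dim=dims)    # `diff_std`
```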

### The `inputs` section
14 changes: 12 additions & 2 deletions example.danra.yaml
@@ -20,8 +20,18 @@ output:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
- mean
- mean_per_gridpoint
- std
- std_per_gridpoint
- diff_mean
- diff_mean_per_gridpoint
- diff_std
- diff_std_per_gridpoint
- diurnal_diff_mean
- diurnal_diff_mean_per_gridpoint
- diurnal_diff_std
- diurnal_diff_std_per_gridpoint
val:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
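
The configuration above replaces the earlier `ops`/`dims` pair with a flat list of named statistics, where the prefix and suffix of each name encode how the reduction is done. A rough sketch of one plausible reading, assuming `_per_gridpoint` means reducing over `time` only (keeping `grid_index`) and `diurnal_diff_*` means differences between values 24 hours apart; the hourly spacing, sizes and variable name below are assumptions for illustration, not the package's implementation:

```python
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("1990-09-03", periods=72, freq="h")
da = xr.DataArray(
    np.random.rand(len(time), 4),
    dims=("time", "grid_index"),
    coords={"time": time},
    name="t2m",
)

# "*_per_gridpoint": reduce over time only, keeping the grid_index dimension
std_per_gridpoint = da.std(dim="time")

# "diurnal_diff_*": difference between values 24 hours apart (24 hourly steps)
diurnal_diff = da - da.shift(time=24)
diurnal_diff_mean = diurnal_diff.mean(dim=["time", "grid_index"])
diurnal_diff_mean_per_gridpoint = diurnal_diff.mean(dim="time")
```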
32 changes: 1 addition & 31 deletions mllam_data_prep/config.py
@@ -84,11 +84,6 @@ class DimMapping:
E.g. `{"grid_index": {"method": "stack", "dims": ["x", "y"]}}` will stack the "x" and "y"
dimensions in the input dataset into a new "grid_index" dimension in the output.

Attributes:
method: The method used for mapping.
dims: The dimensions to be mapped.
name_format: The format for naming the mapped dimensions.

Attributes
----------
method: str
@@ -161,25 +156,6 @@ class InputDataset:
attributes: Dict[str, Any] = None


@dataclass
class Statistics:
"""
Define the statistics to compute for the output dataset, this includes defining
the the statistics to compute and the dimensions to compute the statistics over.
The statistics will be computed for each variable in the output dataset seperately.

Attributes
----------
ops: List[str]
The statistics to compute, e.g. ["mean", "std", "min", "max"].
dims: List[str]
The dimensions to compute the statistics over, e.g. ["time", "grid_index"].
"""

ops: List[str]
dims: List[str]


@dataclass
class Split:
"""
@@ -198,7 +174,7 @@ class Split:

start: str
end: str
compute_statistics: Statistics = None
compute_statistics: List[str] = None


@dataclass
@@ -266,12 +242,6 @@ class Output:
class Config(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
"""Configuration for the model.

Attributes:
schema_version: Version of the config file schema.
dataset_version: Version of the dataset itself.
architecture: Information about the model architecture this dataset is intended for.
inputs: Input datasets for the model.

Attributes
----------
output: Output
12 changes: 6 additions & 6 deletions mllam_data_prep/create_dataset.py
@@ -213,15 +213,15 @@ def create_dataset(config: Config):
logger.info(f"Computing statistics for split {split_name}")
split_stats = calc_stats(
ds=ds_split,
statistics_config=split_config.compute_statistics,
splitting_dim=splitting.dim,
statistic_methods=split_config.compute_statistics,
)
for op, op_dataarrays in split_stats.items():
for var_name, da in op_dataarrays.items():
for op, ds_op in split_stats.items():
for var_name, da in ds_op.items():
ds[f"{var_name}__{split_name}__{op}"] = da

# add a new variable which contains the start, stop for each split, the coords would then be the split names
# and the data would be the start, stop values
# Add a new variable which contains the start, stop for each split,
# the coords would then be the split names and the data would be the
# start, stop values
split_vals = np.array([[split.start, split.end] for split in splits.values()])
da_splits = xr.DataArray(
split_vals,
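
The hunk above also shows the new `calc_stats` signature (`ds`, `splitting_dim`, `statistic_methods`) and that it is expected to return a mapping from statistic name to a reduced dataset, which the calling loop writes out as `{var_name}__{split_name}__{op}` variables. A hypothetical skeleton compatible with that call site (not the PR's actual implementation, and only a few statistic names are dispatched here):

```python
from typing import Dict, List

import xarray as xr


def calc_stats(
    ds: xr.Dataset, splitting_dim: str, statistic_methods: List[str]
) -> Dict[str, xr.Dataset]:
    """Return one reduced dataset per requested statistic name."""
    stats: Dict[str, xr.Dataset] = {}
    for name in statistic_methods:
        if name == "mean":
            # reduce over both the splitting dim (time) and grid_index
            stats[name] = ds.mean(dim=[splitting_dim, "grid_index"])
        elif name == "mean_per_gridpoint":
            # reduce over the splitting dim only, keeping grid_index
            stats[name] = ds.mean(dim=splitting_dim)
        elif name == "diff_mean":
            # mean of consecutive differences along the splitting dim
            ds_diff = ds.diff(dim=splitting_dim)
            stats[name] = ds_diff.mean(dim=[splitting_dim, "grid_index"])
        else:
            raise NotImplementedError(f"statistic {name!r} not sketched here")
    return stats
```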