diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index e2d2140..b9de2da 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -193,11 +193,9 @@ def get_dataarray( """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcing/static). A - datastore must be able to return for the "state" category, but - "forcing" and "static" are optional (in which case the method should - return `None`). For the "static" category the `split` is allowed to be - `None` because the static data is the same for all splits. + space and time) of a given category (state/forcing/static). For the + "static" category the `split` is allowed to be `None` because the static + data is the same for all splits. The returned dataarray is expected to at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index f68bb4d..8f48891 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -218,9 +218,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcin g/static). "state" is - the only required category, for other categories, the method will - return `None` if the category is not found in the datastore. + space and time) of a given category (state/forcing/static). The method + will return `None` if the category is not found in the datastore. The returned dataarray will at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have been stacked diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index a958b8f..1bdbc8c 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -300,11 +300,9 @@ def get_dataarray( """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in - space and time) of a given category (state/forcing/static). A - datastore must be able to return for the "state" category, but - "forcing" and "static" are optional (in which case the method should - return `None`). For the "static" category the `split` is allowed to be - `None` because the static data is the same for all splits. + space and time) of a given category (state/forcing/static). For the + "static" category the `split` is allowed to be `None` because the static + data is the same for all splits. The returned dataarray is expected to at minimum have dimensions of `(grid_index, {category}_feature)` so that any spatial dimensions have