import datetime
import shutil
from collections import defaultdict
from pathlib import Path

import numpy as np
import xarray as xr
from loguru import logger
from numcodecs import Blosc

from . import __version__
from .config import Config, InvalidConfigException
from .ops.derived_variables import derive_variables
from .ops.loading import load_input_dataset
from .ops.mapping import map_dims_and_variables
from .ops.selection import select_by_kwargs
from .ops.statistics import calc_stats
from .ops.subsetting import subset_dataset

# the `extra` field in the config that was added between v0.2.0 and v0.5.0 is
# optional, so we can support both v0.2.0 and v0.5.0
SUPPORTED_CONFIG_VERSIONS = ["v0.2.0", "v0.5.0"]
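# For orientation, `create_dataset` below reads the following from the config
# (an illustrative summary of the attribute accesses in this module, not a full
# schema; names in angle brackets are placeholders):
#
#   config.schema_version, config.dataset_version, config.extra
#   config.output.variables[<target>]  -> list of output dimensions
#   config.output.coord_ranges, config.output.chunking
#   config.output.splitting            -> .dim and .splits, where each split has
#                                         .start, .end and optional .compute_statistics
#   config.inputs[<name>]              -> .path, .dims, .variables,
#                                         .derived_variables, .dim_mapping,
#                                         .target_output_variable, .attributes
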
def _check_dataset_attributes(ds, expected_attributes, dataset_name):
    # check that the dataset has the expected attributes with the expected values
    missing_attributes = set(expected_attributes.keys()) - set(ds.attrs.keys())
    if len(missing_attributes) > 0:
        raise ValueError(
            f"Dataset {dataset_name} is missing the following attributes: {missing_attributes}"
        )

    # check for attributes having the wrong value
    incorrect_attributes = {
        key: val for key, val in expected_attributes.items() if ds.attrs[key] != val
    }
    if len(incorrect_attributes) > 0:
        s_list = "\n".join(
            [
                f"{key}: {val} != {ds.attrs[key]}"
                for key, val in incorrect_attributes.items()
            ]
        )
        raise ValueError(
            f"Dataset {dataset_name} has the following incorrect attributes: {s_list}"
        )

def _merge_dataarrays_by_target(dataarrays_by_target):
    attrs_to_keep = ["source_dataset"]
    dataarrays = []
    for target, das in dataarrays_by_target.items():
        logger.info(f"Merging dataarrays for target variable `{target}`")
        concat_dim = None
        for da in das:
            d = da.attrs.get("variables_mapping_dim", None)
            if d is None:
                raise ValueError(
                    f"Dataarray for target {target} does not have the 'variables_mapping_dim' attribute"
                )
            if concat_dim is not None and d != concat_dim:
                raise ValueError(
                    f"Dataarrays for target {target} have different 'variables_mapping_dim' attributes: {d} != {concat_dim}"
                )
            concat_dim = d

        for da in das:
            for attr in attrs_to_keep:
                # create an aux coord for each attribute we want to keep
                # (for example the name of the source dataset)
                # so that we have this in the resulting dataset
                da.coords[f"{concat_dim}_{attr}"] = xr.DataArray(
                    [da.attrs.pop(attr)] * int(da[concat_dim].count()),
                    dims=[concat_dim],
                )

        da_target = xr.concat(das, dim=concat_dim)
        da_target.name = target
        dataarrays.append(da_target)

    # by doing a merge with join="exact" we make sure that the dataarrays
    # are aligned along the same dimensions, and that the coordinates are
    # the same for all dataarrays. Otherwise xarray will fill in with NaNs
    # for any missing coordinate values
    try:
        ds = xr.merge(dataarrays, join="exact")
    except ValueError as ex:
        if ex.args[0].startswith("cannot align objects with join='exact'"):
            raise InvalidConfigException(
                f"Couldn't merge together the dataarrays for all targets ({', '.join(dataarrays_by_target.keys())})."
                " This is likely because the dataarrays have different dimensions or coordinates."
                " Maybe you need to give the 'feature' dimension a unique name for each target variable?"
            ) from ex
        else:
            raise ex
    return ds

def create_dataset(config: Config):
    """
    Create a dataset from the input datasets specified in the config file.

    Parameters
    ----------
    config : Config
        The configuration object defining the input datasets and how to map them to the output dataset.

    Returns
    -------
    xr.Dataset
        The dataset created from the input datasets with a variable for each output
        as defined in the config file.
    """
    if config.schema_version not in SUPPORTED_CONFIG_VERSIONS:
        raise ValueError(
            f"Unsupported schema version {config.schema_version}. Only schema versions "
            f"{', '.join(SUPPORTED_CONFIG_VERSIONS)} are supported by mllam-data-prep "
            f"v{__version__}."
        )
    if config.schema_version == "v0.2.0" and config.extra is not None:
        raise ValueError(
            "Config schema version v0.2.0 does not support the `extra` field. Please "
            "update the schema version used in your config to v0.5.0."
        )

    output_config = config.output
    output_coord_ranges = output_config.coord_ranges
    chunking_config = config.output.chunking

    dataarrays_by_target = defaultdict(list)

    for dataset_name, input_config in config.inputs.items():
        path = input_config.path
        variables = input_config.variables
        derived_variables = input_config.derived_variables
        target_output_var = input_config.target_output_variable
        expected_input_attributes = input_config.attributes
        expected_input_var_dims = input_config.dims

        output_dims = output_config.variables[target_output_var]

        logger.info(f"Loading dataset {dataset_name} from {path}")
        try:
            ds_input = load_input_dataset(fp=path)
        except Exception as ex:
            raise Exception(f"Error loading dataset {dataset_name} from {path}") from ex

        # Initialize the output dataset and add dimensions
        ds = xr.Dataset()
        ds.attrs.update(ds_input.attrs)
        for dim in ds_input.dims:
            ds = ds.assign_coords({dim: ds_input.coords[dim]})

        if variables:
            logger.info(f"Subsetting dataset {dataset_name}")
            ds = subset_dataset(
                ds_subset=ds,
                ds_input=ds_input,
                variables=variables,
                chunking=chunking_config,
            )

        if derived_variables:
            logger.info(f"Deriving variables from {dataset_name}")
            ds = derive_variables(
                ds=ds,
                ds_input=ds_input,
                derived_variables=derived_variables,
                chunking=chunking_config,
            )

        _check_dataset_attributes(
            ds=ds,
            expected_attributes=expected_input_attributes,
            dataset_name=dataset_name,
        )

        dim_mapping = input_config.dim_mapping

        # check that there is an entry for each output dimension
        # in the dim_mapping so that we know how to construct the
        # final dataset
        missing_dims = set(output_dims) - set(dim_mapping.keys())
        if missing_dims:
            raise ValueError(
                f"Missing dimension mapping for {missing_dims}"
                f" for input dataset {dataset_name}, please provide"
                " a mapping for all output dimensions by"
                " using the 'dim_mapping' key in the input dataset"
            )

        logger.info(
            f"Mapping dimensions and variables for dataset {dataset_name} to {target_output_var}"
        )
        try:
            da_target = map_dims_and_variables(
                ds=ds,
                dim_mapping=dim_mapping,
                expected_input_var_dims=expected_input_var_dims,
            )
        except Exception as ex:
            raise Exception(
                "There was an issue stacking dimensions and variables to"
                f" produce variable {target_output_var} from dataset {dataset_name}"
            ) from ex

        da_target.attrs["source_dataset"] = dataset_name

        # only need to do selection for the coordinates that the input dataset actually has
        if output_coord_ranges is not None:
            selection_kwargs = {}
            for dim in output_dims:
                if dim in output_coord_ranges:
                    selection_kwargs[dim] = output_coord_ranges[dim]
            da_target = select_by_kwargs(da_target, **selection_kwargs)

        dataarrays_by_target[target_output_var].append(da_target)

    ds = _merge_dataarrays_by_target(dataarrays_by_target=dataarrays_by_target)

    # need to drop the encoding so that we can write to zarr with new chunksizes
    ds = ds.drop_encoding()

    # default to making a single chunk for each dimension if chunksize is not specified
    # in the config
    logger.info(f"Chunking dataset with {chunking_config}")
    chunks = {dim: chunking_config.get(dim, int(ds[dim].count())) for dim in ds.dims}
    ds = ds.chunk(chunks)

    splitting = config.output.splitting

    if splitting is not None:
        splits = splitting.splits
        logger.info(
            f"Setting splitting information to define `{list(splits.keys())}` splits "
            f"along dimension `{splitting.dim}`"
        )

        for split_name, split_config in splits.items():
            if split_config.compute_statistics is not None:
                ds_split = ds.sel(
                    {splitting.dim: slice(split_config.start, split_config.end)}
                )
                logger.info(f"Computing statistics for split {split_name}")
                split_stats = calc_stats(
                    ds=ds_split,
                    statistics_config=split_config.compute_statistics,
                    splitting_dim=splitting.dim,
                )
                for op, op_dataarrays in split_stats.items():
                    for var_name, da in op_dataarrays.items():
                        ds[f"{var_name}__{split_name}__{op}"] = da

        # add a new variable which contains the start and end of each split; the
        # coords are the split names and the data are the start/end values
        split_vals = np.array([[split.start, split.end] for split in splits.values()])
        da_splits = xr.DataArray(
            split_vals,
            dims=["split_name", "split_part"],
            coords={"split_name": list(splits.keys()), "split_part": ["start", "end"]},
        )
        ds["splits"] = da_splits

    ds.attrs = {}
    ds.attrs["schema_version"] = config.schema_version
    ds.attrs["dataset_version"] = config.dataset_version
    ds.attrs["created_on"] = datetime.datetime.now().replace(microsecond=0).isoformat()
    ds.attrs[
        "created_with"
    ] = "mllam-data-prep (https://github.com/mllam/mllam-data-prep)"
    ds.attrs["mdp_version"] = f"v{__version__}"

    return ds

def create_dataset_zarr(fp_config, fp_zarr: str = None):
    """
    Create a dataset from the input datasets specified in the config file and write it to a zarr file.
    The path to the zarr file is the same as the config file, but with the extension changed to '.zarr'.

    Parameters
    ----------
    fp_config : Path
        The path to the configuration file.
    fp_zarr : Path, optional
        The path to the zarr file to write the dataset to. If not provided, the zarr file will be written
        to the same directory as the config file with the extension changed to '.zarr'.
    """
    config = Config.load_config(file=fp_config)

    ds = create_dataset(config=config)

    logger.info("Writing dataset to zarr")
    if fp_zarr is None:
        fp_zarr = fp_config.parent / fp_config.name.replace(".yaml", ".zarr")
    else:
        fp_zarr = Path(fp_zarr)

    if fp_zarr.exists():
        logger.info(f"Removing existing dataset at {fp_zarr}")
        shutil.rmtree(fp_zarr)

    # use zstd compression since it has a good balance of speed and compression ratio
    # https://engineering.fb.com/2016/08/31/core-infra/smaller-and-faster-data-compression-with-zstandard/
    compressor = Blosc(cname="zstd", clevel=1, shuffle=Blosc.BITSHUFFLE)
    encoding = {v: {"compressor": compressor} for v in ds.data_vars}

    ds.to_zarr(fp_zarr, consolidated=True, mode="w", encoding=encoding)
    logger.info(f"Wrote training-ready dataset to {fp_zarr}")

    logger.info(ds)
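

# A minimal usage sketch (illustrative only): the config path below is a
# placeholder, and the package may expose its own command-line entrypoint
# separately. Running this module directly with such a config would build the
# dataset it describes and write it to a zarr store next to the config file.
if __name__ == "__main__":
    example_config = Path("example.config.yaml")  # placeholder path, not shipped here
    create_dataset_zarr(fp_config=example_config)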