Skip to content

Commit

Permalink
Ran build_cleaner. Also some internal dependency changes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630191048
  • Loading branch information
langmore authored and Weatherbench2 authors committed May 7, 2024
1 parent 1d3cb4d commit 04ade2e
Show file tree
Hide file tree
Showing 17 changed files with 780 additions and 49 deletions.
44 changes: 43 additions & 1 deletion docs/source/command-line-scripts.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,52 @@ _Command options_:
* `--working_chunks`: Spatial chunk sizes to use during time downsampling, e.g., "longitude=10,latitude=10". They may not include "time".
* `--beam_runner`: Beam runner. Use `DirectRunner` for local execution.

## Slice dataset
Slices a Zarr file containing an xarray Dataset, using `.sel` and `.isel`.

```
usage: slice_dataset.py [-h]
[--input_path INPUT_PATH]
[--output_path OUTPUT_PATH]
[--sel SEL]
[--isel ISEL]
[--drop_variables DROP_VARIABLES]
[--keep_variables KEEP_VARIABLES]
[--output_chunks OUTPUT_CHUNKS]
[--runner RUNNER]
```

_Command options_:

* `--input_path`: (required) Input Zarr path
* `--output_path`: (required) Output Zarr path
* `--sel`: Selection criteria, to pass to `xarray.Dataset.sel`. Passed as
key=value pairs, with key = `VARNAME_{start,stop,step}`
* `--isel`: Selection criteria, to pass to `xarray.Dataset.isel`. Passed as
key=value pairs, with key = `VARNAME_{start,stop,step}`
* `--drop_variables`: Comma delimited list of variables to drop. If empty, drop
no variables.
* `--keep_variables`: Comma delimited list of variables to keep. If empty, use
`--drop_variables` to determine which variables to keep.
* `--output_chunks`: Chunk sizes overriding input chunks.
* `--runner`: Beam runner. Use `DirectRunner` for local execution.

*Example*

```bash
python slice_dataset.py -- \
--input_path=gs://weatherbench2/datasets/ens/2018-64x32_equiangular_with_poles_conservative.zarr \
--output_path=PATH \
--sel="prediction_timedelta_stop=15 days,latitude_start=-33.33,latitude_stop=33.33" \
--isel="longitude_start=0,longitude_stop=180,longitude_step=40" \
--keep_variables=geopotential,temperature
```

## Expand climatology

`expand_climatology.py` takes a climatology dataset and expands it into a forecast-like format (`init_time` + `lead_time`). This is not currently used as `evaluation.py` is able to do this on-the-fly, reducing the number of intermediate steps. We still included the script here in case others find it useful.

## Init to valid time conversion

`compute_init_to_valid_time.py` converts forecasts in init-time convention to valid-time convention. Since currently, we do all evaluation in the init-time format, this script is not used.
`compute_init_to_valid_time.py` converts forecasts in init-time convention to valid-time convention. Since currently, we do all evaluation in the init-time format, this script is not used.
17 changes: 15 additions & 2 deletions scripts/compute_averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


# pylint: disable=expression-not-assigned
Expand Down Expand Up @@ -120,7 +125,10 @@ def main(argv: list[str]):

with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
chunked = root | xbeam.DatasetToChunks(
source_dataset, source_chunks, split_vars=True
source_dataset,
source_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)

if weights is not None:
Expand All @@ -131,7 +139,12 @@ def main(argv: list[str]):
(
chunked
| xbeam.Mean(AVERAGING_DIMS.value, skipna=False, fanout=FANOUT.value)
| xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
target_chunks,
num_threads=NUM_THREADS.value,
)
)


Expand Down
16 changes: 14 additions & 2 deletions scripts/compute_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@
'precipitation variable. In mm.'
),
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


class Quantile:
Expand Down Expand Up @@ -330,6 +335,10 @@ def _compute_seeps(kv):
if stat not in ['seeps', 'mean']:
for var in raw_vars:
if stat == 'quantile':
if not quantiles:
raise ValueError(
'Cannot compute stat `quantile` without specifying --quantiles.'
)
quantile_dim = xr.DataArray(
quantiles, name='quantile', dims=['quantile']
)
Expand All @@ -349,7 +358,10 @@ def _compute_seeps(kv):
pcoll = (
root
| xbeam.DatasetToChunks(
obs, input_chunks, split_vars=True, num_threads=16
obs,
input_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)
| 'RechunkIn'
>> xbeam.Rechunk( # pytype: disable=wrong-arg-types
Expand Down Expand Up @@ -412,7 +424,7 @@ def _compute_seeps(kv):
OUTPUT_PATH.value,
template=clim_template,
zarr_chunks=output_chunks,
num_threads=16,
num_threads=NUM_THREADS.value,
)
)

Expand Down
17 changes: 15 additions & 2 deletions scripts/compute_derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@
MAX_MEM_GB = flags.DEFINE_integer(
'max_mem_gb', 1, help='Max memory for rechunking in GB.'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)

RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')

Expand Down Expand Up @@ -226,7 +231,12 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool:
# so that with and without rechunking can be computed in parallel
pcoll = (
root
| xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False)
| xbeam.DatasetToChunks(
source_dataset,
source_chunks,
split_vars=False,
num_threads=NUM_THREADS.value,
)
| beam.MapTuple(
lambda k, v: ( # pylint: disable=g-long-lambda
k,
Expand Down Expand Up @@ -274,7 +284,10 @@ def _is_not_precip(kv: tuple[xbeam.Key, xr.Dataset]) -> bool:

# Combined
_ = pcoll | xbeam.ChunksToZarr(
OUTPUT_PATH.value, template, source_chunks, num_threads=16
OUTPUT_PATH.value,
template,
source_chunks,
num_threads=NUM_THREADS.value,
)


Expand Down
19 changes: 17 additions & 2 deletions scripts/compute_ensemble_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
'2020-12-31',
help='ISO 8601 timestamp (inclusive) at which to stop evaluation',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


# pylint: disable=expression-not-assigned
Expand Down Expand Up @@ -88,9 +93,19 @@ def main(argv: list[str]):
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
(
root
| xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=True)
| xbeam.DatasetToChunks(
source_dataset,
source_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)
| xbeam.Mean(REALIZATION_NAME.value, skipna=False)
| xbeam.ChunksToZarr(OUTPUT_PATH.value, template, target_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
target_chunks,
num_threads=NUM_THREADS.value,
)
)


Expand Down
10 changes: 9 additions & 1 deletion scripts/compute_statistical_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
RECHUNK_ITEMSIZE = flags.DEFINE_integer(
'rechunk_itemsize', 4, help='Itemsize for rechunking.'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)


def moment_reduce(
Expand Down Expand Up @@ -143,7 +148,9 @@ def main(argv: list[str]) -> None:

with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
# Read
pcoll = root | xbeam.DatasetToChunks(obs, input_chunks, split_vars=True)
pcoll = root | xbeam.DatasetToChunks(
obs, input_chunks, split_vars=True, num_threads=NUM_THREADS.value
)

# Branches to compute statistical moments
pcolls = []
Expand Down Expand Up @@ -174,6 +181,7 @@ def main(argv: list[str]) -> None:
OUTPUT_PATH.value,
template=output_template,
zarr_chunks=output_chunks,
num_threads=NUM_THREADS.value,
)
)

Expand Down
17 changes: 15 additions & 2 deletions scripts/compute_zonal_energy_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)

RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')

Expand Down Expand Up @@ -196,7 +201,12 @@ def main(argv: list[str]) -> None:
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
_ = (
root
| xbeam.DatasetToChunks(source_dataset, source_chunks, split_vars=False)
| xbeam.DatasetToChunks(
source_dataset,
source_chunks,
split_vars=False,
num_threads=NUM_THREADS.value,
)
| beam.MapTuple(
lambda k, v: ( # pylint: disable=g-long-lambda
k,
Expand All @@ -207,7 +217,10 @@ def main(argv: list[str]) -> None:
| beam.MapTuple(_strip_offsets)
| xbeam.Mean(AVERAGING_DIMS.value, fanout=FANOUT.value)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value, template, output_chunks, num_threads=16
OUTPUT_PATH.value,
template,
output_chunks,
num_threads=NUM_THREADS.value,
)
)

Expand Down
16 changes: 14 additions & 2 deletions scripts/convert_init_to_valid_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@
INPUT_PATH = flags.DEFINE_string('input_path', None, help='zarr inputs')
OUTPUT_PATH = flags.DEFINE_string('output_path', None, help='zarr outputs')
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)

TIME = 'time'
DELTA = 'prediction_timedelta'
Expand Down Expand Up @@ -254,7 +259,9 @@ def main(argv: list[str]) -> None:
source_ds.indexes[INIT],
)
)
p |= xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True)
p |= xarray_beam.DatasetToChunks(
source_ds, input_chunks, split_vars=True, num_threads=NUM_THREADS.value
)
if input_chunks != split_chunks:
p |= xarray_beam.SplitChunks(split_chunks)
p |= beam.FlatMapTuple(
Expand All @@ -266,7 +273,12 @@ def main(argv: list[str]) -> None:
p = (p, padding) | beam.Flatten()
if input_chunks != split_chunks:
p |= xarray_beam.ConsolidateChunks(output_chunks)
p |= xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
p |= xarray_beam.ChunksToZarr(
OUTPUT_PATH.value,
template,
output_chunks,
num_threads=NUM_THREADS.value,
)


if __name__ == '__main__':
Expand Down
6 changes: 6 additions & 0 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@
None,
help='Beam CombineFn fanout. Might be required for large dataset.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write Zarr in parallel per worker.',
)


def _wind_vector_error(err_type: str):
Expand Down Expand Up @@ -623,6 +628,7 @@ def main(argv: list[str]) -> None:
runner=RUNNER.value,
input_chunks=INPUT_CHUNKS.value,
fanout=FANOUT.value,
num_threads=NUM_THREADS.value,
argv=argv,
)
else:
Expand Down
10 changes: 9 additions & 1 deletion scripts/expand_climatology.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@
None,
help='Desired integer chunk size. If not set, inferred from input chunks.',
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


Expand Down Expand Up @@ -149,7 +154,10 @@ def main(argv: list[str]) -> None:
| beam.Reshuffle()
| beam.FlatMap(select_climatology, climatology, times, base_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value, template=template, zarr_chunks=output_chunks
OUTPUT_PATH.value,
template=template,
zarr_chunks=output_chunks,
num_threads=NUM_THREADS.value,
)
)

Expand Down
19 changes: 17 additions & 2 deletions scripts/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@
LONGITUDE_NAME = flags.DEFINE_string(
'longitude_name', 'longitude', help='Name of longitude dimension in dataset'
)
NUM_THREADS = flags.DEFINE_integer(
'num_threads',
None,
help='Number of chunks to read/write in parallel per worker.',
)
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


Expand Down Expand Up @@ -135,11 +140,21 @@ def main(argv):
with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
_ = (
root
| xarray_beam.DatasetToChunks(source_ds, input_chunks, split_vars=True)
| xarray_beam.DatasetToChunks(
source_ds,
input_chunks,
split_vars=True,
num_threads=NUM_THREADS.value,
)
| 'Regrid'
>> beam.MapTuple(lambda k, v: (k, regridder.regrid_dataset(v)))
| xarray_beam.ConsolidateChunks(output_chunks)
| xarray_beam.ChunksToZarr(OUTPUT_PATH.value, template, output_chunks)
| xarray_beam.ChunksToZarr(
OUTPUT_PATH.value,
template,
output_chunks,
num_threads=NUM_THREADS.value,
)
)


Expand Down
Loading

0 comments on commit 04ade2e

Please sign in to comment.