Commit
pass var names from file
Anna Kwa committed Apr 1, 2020
1 parent a569561 commit 354d675
Showing 4 changed files with 77 additions and 63 deletions.
24 changes: 16 additions & 8 deletions fv3net/pipelines/create_training_data/__main__.py
@@ -1,5 +1,7 @@
import argparse
from fv3net.pipelines.create_training_data.pipeline import run
import yaml

from .pipeline import run

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -21,6 +23,12 @@
help="Write path for train data in Google Cloud Storage bucket. "
"Don't include bucket in path.",
)
parser.add_argument(
"variable_namefile",
type=str,
default=None,
help="yaml file for providing data variable names",
)
parser.add_argument(
"--timesteps-per-output-file",
type=int,
@@ -38,11 +46,11 @@
"Output zarr files will be saved in either 'train' or 'test' subdir of "
"gcs-output-data-dir",
)
parser.add_argument(
"--var-names-yaml",
type=str,
default=None,
help="optional yaml for providing data variable names",
)

args, pipeline_args = parser.parse_known_args()
run(args=args, pipeline_args=pipeline_args)
with open(args.variable_namefile, "r") as stream:
try:
names = yaml.safe_load(stream)
except yaml.YAMLError as exc:
raise ValueError(f"Bad yaml config: {exc}")
run(args=args, pipeline_args=pipeline_args, names=names)
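
The new required positional `variable_namefile` argument replaces the old optional `--var-names-yaml` flag, and `yaml.safe_load` hands `run` a plain dict, which is why the pipeline code further down switches from attribute access (`names.init_time_dim`) to key access (`names["init_time_dim"]`). A minimal sketch of that loading behavior, using a made-up inline YAML string instead of a real config file:

```python
import yaml

# Illustrative YAML text mirroring a few keys from variable_names.yml
# (hypothetical inline string; the pipeline reads the user-supplied
# positional `variable_namefile` instead).
example_yaml = """
init_time_dim: "initial_time"
forecast_time_dim: "forecast_time"
grid_vars:
  - "area"
  - "lat"
  - "lon"
"""

names = yaml.safe_load(example_yaml)   # safe_load returns a plain dict
print(names["init_time_dim"])          # -> initial_time
print(names["grid_vars"])              # -> ['area', 'lat', 'lon']

# Malformed YAML raises a yaml.YAMLError subclass, which __main__.py
# converts into a ValueError("Bad yaml config: ...").
try:
    yaml.safe_load("grid_vars: [unclosed")
except yaml.YAMLError as exc:
    print(f"Bad yaml config: {exc}")
```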
10 changes: 0 additions & 10 deletions fv3net/pipelines/create_training_data/helpers.py
@@ -53,16 +53,6 @@ def _path_from_first_timestep(ds, train_test_labels=None):
return os.path.join(train_test_subdir, timestep + ".zarr")


def _set_relative_forecast_time_coord(ds):
delta_t_forecast = (
ds[FORECAST_TIME_DIM].values[-1] - ds[FORECAST_TIME_DIM].values[-2]
)
ds.reset_index([FORECAST_TIME_DIM], drop=True)
return ds.assign_coords(
{FORECAST_TIME_DIM: [timedelta(seconds=0), delta_t_forecast]}
)


def load_hires_prog_diag(diag_data_path, init_times):
"""Loads coarsened diagnostic variables from the prognostic high res run.
49 changes: 21 additions & 28 deletions fv3net/pipelines/create_training_data/pipeline.py
@@ -17,44 +17,36 @@
logger.setLevel(logging.INFO)

GRID_SPEC_FILENAME = "grid_spec.zarr"
GRID_VARS = [
"area",
names.var_lat_outer,
names.var_lon_outer,
names.var_lat_center,
names.var_lon_center,
]
_CHUNK_SIZES = {
"tile": 1,
names.init_time_dim: 1,
names.coord_y_center: 24,
names.coord_x_center: 24,
names.coord_z_center: 79,
}

# forecast time step used to calculate the FV3 run tendency
FORECAST_TIME_INDEX_FOR_C48_TENDENCY = 14
# forecast time step used to calculate the high res tendency
FORECAST_TIME_INDEX_FOR_HIRES_TENDENCY = FORECAST_TIME_INDEX_FOR_C48_TENDENCY


def run(args, pipeline_args):
def run(args, pipeline_args, names):
fs = get_fs(args.gcs_input_data_path)
ds_full = xr.open_zarr(fs.get_mapper(args.gcs_input_data_path))
_save_grid_spec(
ds_full,
args.gcs_output_data_dir,
grid_vars=GRID_VARS,
grid_vars=names["grid_vars"],
grid_spec_filename=GRID_SPEC_FILENAME,
init_time_dim=names.init_time_dim,
init_time_dim=names["init_time_dim"],
)
data_batches, train_test_labels = _divide_data_batches(
ds_full,
args.timesteps_per_output_file,
args.train_fraction,
init_time_dim=names.init_time_dim,
init_time_dim=names["init_time_dim"],
)

chunk_sizes = {
"tile": 1,
names["init_time_dim"]: 1,
names["coord_y_center"]: 24,
names["coord_x_center"]: 24,
names["coord_z_center"]: 79,
}
logger.info(f"Processing {len(data_batches)} subsets...")
beam_options = PipelineOptions(flags=pipeline_args, save_main_session=True)
with beam.Pipeline(options=beam_options) as p:
@@ -64,11 +56,12 @@ def run(args, pipeline_args):
| "CreateTrainingCols"
>> beam.Map(
_create_train_cols,
init_time_dim=names.init_time_dim,
step_time_dim=names.step_time_dim,
forecast_time_dim=names.forecast_time_dim,
coord_begin_step=names.coord_begin_step,
var_source_name_map=names.var_source_name_map,
cols_to_keep=names["one_step_vars"] + names["target_vars"],
init_time_dim=names["init_time_dim"],
step_time_dim=names["step_time_dim"],
forecast_time_dim=names["forecast_time_dim"],
coord_begin_step=names["coord_begin_step"],
var_source_name_map=names["var_source_name_map"],
forecast_timestep_for_onestep=FORECAST_TIME_INDEX_FOR_C48_TENDENCY,
forecast_timestep_for_highres=FORECAST_TIME_INDEX_FOR_HIRES_TENDENCY,
)
@@ -77,15 +70,15 @@
_merge_hires_data,
diag_c48_path=args.diag_c48_path,
coarsened_diags_zarr_name=COARSENED_DIAGS_ZARR_NAME,
renamed_high_res_vars=names.renamed_high_res_vars,
init_time_dim=names.init_time_dim,
renamed_high_res_vars=names["renamed_high_res_vars"],
init_time_dim=names["init_time_dim"],
)
| "WriteToZarr"
>> beam.Map(
_write_remote_train_zarr,
gcs_output_dir=args.gcs_output_data_dir,
train_test_labels=train_test_labels,
chunk_sizes=_CHUNK_SIZES,
chunk_sizes=chunk_sizes,
)
)

@@ -248,7 +241,7 @@ def _create_train_cols(
step_time_dim,
coord_begin_step,
var_source_name_map,
cols_to_keep=names.one_step_vars + names.target_vars,
cols_to_keep=names["one_step_vars"] + names["target_vars"],
):
""" Calculate apparent sources for target variables and keep feature vars
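
With the module-level `GRID_VARS` and `_CHUNK_SIZES` constants removed, `run()` now builds the grid-variable list and chunk sizes from the `names` dict it receives. The diff does not show the bodies of `_save_grid_spec` or `_write_remote_train_zarr`, so the xarray calls below are only an assumed illustration of how configured variable and dimension names might be applied; the dataset and keys are invented for the sketch:

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for the dict parsed from variable_names.yml;
# only the keys used below are included.
names = {
    "coord_x_center": "x",
    "coord_y_center": "y",
    "grid_vars": ["area", "lat", "lon"],
}

# A tiny fake dataset standing in for the opened zarr store.
ds = xr.Dataset(
    {
        "area": (("tile", "y", "x"), np.ones((1, 48, 48))),
        "lat": (("tile", "y", "x"), np.zeros((1, 48, 48))),
        "lon": (("tile", "y", "x"), np.zeros((1, 48, 48))),
    }
)

# Grid variables are selected by the names supplied at run time
# rather than by a module-level GRID_VARS constant.
grid_spec = ds[names["grid_vars"]]

# Likewise, chunk sizes are keyed on dimension names from the config
# (Dataset.chunk requires dask to be installed).
chunk_sizes = {
    "tile": 1,
    names["coord_y_center"]: 24,
    names["coord_x_center"]: 24,
}
rechunked = grid_spec.chunk(chunk_sizes)
print(rechunked.chunks)
```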
57 changes: 40 additions & 17 deletions fv3net/pipelines/create_training_data/variable_names.yml
@@ -1,22 +1,39 @@
# suffixes that denote whether diagnostic variable is from the coarsened
# high resolution prognostic run or the coarse res one step train data run
suffix_hires: "prog"
suffix_coarse_train: "train"

initial_time_dim: "initial_time"
init_time_dim: "initial_time"
forecast_time_dim: "forecast_time"
step_time_dim: "step"
end_step_coord: "after_physics"
x_coord: "x"
y_coord: "y"
z_coord: "z"

grid_dim_renaming:
"grid_xt": "x"
"grid_yt": "y"
"grid_x": "x_interface"
"grid_y": "y_interface"

x_wind_var: "x_wind"
y_wind_var: "y_wind"
temperature_var: "air_temperature"
specific_humidity_var: "specific_humidity"
coord_begin_step: "begin"

coord_x_center : "x"
coord_y_center : "y"
coord_z_center : "z"

var_lon_center: "lon"
var_lat_center: "lat"
var_lon_outer: "lonb"
var_lat_outer: "latb"

renamed_dims:
grid_xt: "x"
grid_yt: "y"
grid_x: "x_interface"
grid_y: "y_interface"

grid_vars:
- "area"
- "latb"
- "lonb"
- "lat"
- "lon"

var_x_wind: "x_wind"
var_y_wind: "y_wind"
var_temp: "air_temperature"
var_sphum: "specific_humidity"

radiation_variables:
- "DSWRFtoa"
@@ -51,4 +68,10 @@ high_res_data_variables:
- "ULWRFtoa_coarse"
- "ULWRFsfc_coarse"
- "SHTFLsfc"
- "LHTFLsfc"
- "LHTFLsfc"

var_source_name_map:
var_x_wind: "dQU"
var_y_wind: "dQV"
var_temp: "dQ1"
var_sphum: "dQ2"
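
The reworked variable_names.yml also carries `grid_vars`, the renamed dimension map, and `var_source_name_map`, which pairs each variable key with the name of its apparent-source output (dQU, dQV, dQ1, dQ2). How `_create_train_cols` actually consumes that map is not visible in this diff; the snippet below is just one plausible, hypothetical way to resolve it against the per-variable names defined above:

```python
import yaml

# Hypothetical excerpt of the new variable_names.yml keys; the
# resolution logic below is illustrative, not the pipeline's code.
config_text = """
var_x_wind: "x_wind"
var_temp: "air_temperature"
var_source_name_map:
  var_x_wind: "dQU"
  var_temp: "dQ1"
"""
names = yaml.safe_load(config_text)

# var_source_name_map is keyed on the config keys (e.g. "var_temp"),
# so mapping dataset variable names to source names takes two lookups:
# config key -> dataset variable name, and config key -> source name.
source_names = {
    names[config_key]: source_name
    for config_key, source_name in names["var_source_name_map"].items()
}
print(source_names)  # -> {'x_wind': 'dQU', 'air_temperature': 'dQ1'}
```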
