Load create_training_data variable names from a YAML file.

__main__ now takes a positional `variable_namefile`, parses it with
yaml.safe_load and passes the resulting dict to run() as `names`; pipeline.py
replaces attribute-style `names.<attr>` lookups with `names["<key>"]`.

Review notes:
- Indentation and blank context lines were reconstructed; verify the context
  against the target tree before applying.
- The `index` blob ids are carried over from the original patch and do not
  reflect the corrected content below.
- `cols_to_keep` in `_create_train_cols` is now a required parameter: the dict
  `names` only exists inside run(), so it cannot be used in a default value.

diff --git a/fv3net/pipelines/create_training_data/__main__.py b/fv3net/pipelines/create_training_data/__main__.py
index 9df839677b..f694fe0d1d 100644
--- a/fv3net/pipelines/create_training_data/__main__.py
+++ b/fv3net/pipelines/create_training_data/__main__.py
@@ -1,5 +1,7 @@
 import argparse
-from fv3net.pipelines.create_training_data.pipeline import run
+import yaml
+
+from .pipeline import run
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -21,6 +23,12 @@
         help="Write path for train data in Google Cloud Storage bucket. "
         "Don't include bucket in path.",
     )
+    parser.add_argument(
+        "variable_namefile",
+        type=str,
+        default=None,
+        help="yaml file for providing data variable names",
+    )
     parser.add_argument(
         "--timesteps-per-output-file",
         type=int,
@@ -38,11 +46,11 @@
         "Output zarr files will be saved in either 'train' or 'test' subdir of "
         "gcs-output-data-dir",
     )
-    parser.add_argument(
-        "--var-names-yaml",
-        type=str,
-        default=None,
-        help="optional yaml for providing data variable names",
-    )
+
     args, pipeline_args = parser.parse_known_args()
-    run(args=args, pipeline_args=pipeline_args)
+    with open(args.variable_namefile, "r") as stream:
+        try:
+            names = yaml.safe_load(stream)
+        except yaml.YAMLError as exc:
+            raise ValueError(f"Bad yaml config: {exc}") from exc
+    run(args=args, pipeline_args=pipeline_args, names=names)
diff --git a/fv3net/pipelines/create_training_data/helpers.py b/fv3net/pipelines/create_training_data/helpers.py
index 0f401c7cc4..9048806812 100644
--- a/fv3net/pipelines/create_training_data/helpers.py
+++ b/fv3net/pipelines/create_training_data/helpers.py
@@ -53,16 +53,6 @@ def _path_from_first_timestep(ds, train_test_labels=None):
     return os.path.join(train_test_subdir, timestep + ".zarr")
 
 
-def _set_relative_forecast_time_coord(ds):
-    delta_t_forecast = (
-        ds[FORECAST_TIME_DIM].values[-1] - ds[FORECAST_TIME_DIM].values[-2]
-    )
-    ds.reset_index([FORECAST_TIME_DIM], drop=True)
-    return ds.assign_coords(
-        {FORECAST_TIME_DIM: [timedelta(seconds=0), delta_t_forecast]}
-    )
-
-
 def load_hires_prog_diag(diag_data_path, init_times):
     """Loads coarsened diagnostic variables from the prognostic high res run.
 
diff --git a/fv3net/pipelines/create_training_data/pipeline.py b/fv3net/pipelines/create_training_data/pipeline.py
index 7943ee6be1..79d0dc43b7 100644
--- a/fv3net/pipelines/create_training_data/pipeline.py
+++ b/fv3net/pipelines/create_training_data/pipeline.py
@@ -17,20 +17,6 @@
 logger.setLevel(logging.INFO)
 
 GRID_SPEC_FILENAME = "grid_spec.zarr"
-GRID_VARS = [
-    "area",
-    names.var_lat_outer,
-    names.var_lon_outer,
-    names.var_lat_center,
-    names.var_lon_center,
-]
-_CHUNK_SIZES = {
-    "tile": 1,
-    names.init_time_dim: 1,
-    names.coord_y_center: 24,
-    names.coord_x_center: 24,
-    names.coord_z_center: 79,
-}
 
 # forecast time step used to calculate the FV3 run tendency
 FORECAST_TIME_INDEX_FOR_C48_TENDENCY = 14
@@ -38,23 +24,29 @@
 FORECAST_TIME_INDEX_FOR_HIRES_TENDENCY = FORECAST_TIME_INDEX_FOR_C48_TENDENCY
 
 
-def run(args, pipeline_args):
+def run(args, pipeline_args, names):
     fs = get_fs(args.gcs_input_data_path)
     ds_full = xr.open_zarr(fs.get_mapper(args.gcs_input_data_path))
     _save_grid_spec(
         ds_full,
         args.gcs_output_data_dir,
-        grid_vars=GRID_VARS,
+        grid_vars=names["grid_vars"],
         grid_spec_filename=GRID_SPEC_FILENAME,
-        init_time_dim=names.init_time_dim,
+        init_time_dim=names["init_time_dim"],
     )
     data_batches, train_test_labels = _divide_data_batches(
         ds_full,
         args.timesteps_per_output_file,
         args.train_fraction,
-        init_time_dim=names.init_time_dim,
+        init_time_dim=names["init_time_dim"],
     )
-
+    chunk_sizes = {
+        "tile": 1,
+        names["init_time_dim"]: 1,
+        names["coord_y_center"]: 24,
+        names["coord_x_center"]: 24,
+        names["coord_z_center"]: 79,
+    }
     logger.info(f"Processing {len(data_batches)} subsets...")
     beam_options = PipelineOptions(flags=pipeline_args, save_main_session=True)
     with beam.Pipeline(options=beam_options) as p:
@@ -64,11 +56,12 @@ def run(args, pipeline_args):
             | "CreateTrainingCols"
             >> beam.Map(
                 _create_train_cols,
-                init_time_dim=names.init_time_dim,
-                step_time_dim=names.step_time_dim,
-                forecast_time_dim=names.forecast_time_dim,
-                coord_begin_step=names.coord_begin_step,
-                var_source_name_map=names.var_source_name_map,
+                cols_to_keep=names["one_step_vars"] + names["target_vars"],
+                init_time_dim=names["init_time_dim"],
+                step_time_dim=names["step_time_dim"],
+                forecast_time_dim=names["forecast_time_dim"],
+                coord_begin_step=names["coord_begin_step"],
+                var_source_name_map=names["var_source_name_map"],
                 forecast_timestep_for_onestep=FORECAST_TIME_INDEX_FOR_C48_TENDENCY,
                 forecast_timestep_for_highres=FORECAST_TIME_INDEX_FOR_HIRES_TENDENCY,
             )
@@ -77,15 +70,15 @@ def run(args, pipeline_args):
                 _merge_hires_data,
                 diag_c48_path=args.diag_c48_path,
                 coarsened_diags_zarr_name=COARSENED_DIAGS_ZARR_NAME,
-                renamed_high_res_vars=names.renamed_high_res_vars,
-                init_time_dim=names.init_time_dim,
+                renamed_high_res_vars=names["renamed_high_res_vars"],
+                init_time_dim=names["init_time_dim"],
             )
             | "WriteToZarr"
             >> beam.Map(
                 _write_remote_train_zarr,
                 gcs_output_dir=args.gcs_output_data_dir,
                 train_test_labels=train_test_labels,
-                chunk_sizes=_CHUNK_SIZES,
+                chunk_sizes=chunk_sizes,
             )
         )
 
@@ -248,7 +241,7 @@ def _create_train_cols(
     step_time_dim,
     coord_begin_step,
     var_source_name_map,
-    cols_to_keep=names.one_step_vars + names.target_vars,
+    cols_to_keep,
 ):
     """
     Calculate apparent sources for target variables and keep feature vars
diff --git a/fv3net/pipelines/create_training_data/variable_names.yml b/fv3net/pipelines/create_training_data/variable_names.yml
index 623e797433..dac80e6645 100644
--- a/fv3net/pipelines/create_training_data/variable_names.yml
+++ b/fv3net/pipelines/create_training_data/variable_names.yml
@@ -1,22 +1,39 @@
+# suffixes that denote whether diagnostic variable is from the coarsened
+# high resolution prognostic run or the coarse res one step train data run
+suffix_hires: "prog"
+suffix_coarse_train: "train"
 
-initial_time_dim: "initial_time"
+init_time_dim: "initial_time"
 forecast_time_dim: "forecast_time"
 step_time_dim: "step"
-end_step_coord: "after_physics"
-x_coord: "x"
-y_coord: "y"
-z_coord: "z"
-
-grid_dim_renaming:
-  "grid_xt": "x"
-  "grid_yt": "y"
-  "grid_x": "x_interface"
-  "grid_y": "y_interface"
-
-x_wind_var: "x_wind"
-y_wind_var: "y_wind"
-temperature_var: "air_temperature"
-specific_humidity_var: "specific_humidity"
+coord_begin_step: "begin"
+
+coord_x_center : "x"
+coord_y_center : "y"
+coord_z_center : "z"
+
+var_lon_center: "lon"
+var_lat_center: "lat"
+var_lon_outer: "lonb"
+var_lat_outer: "latb"
+
+renamed_dims:
+  grid_xt: "x"
+  grid_yt: "y"
+  grid_x: "x_interface"
+  grid_y: "y_interface"
+
+grid_vars:
+  - "area"
+  - "latb"
+  - "lonb"
+  - "lat"
+  - "lon"
+
+var_x_wind: "x_wind"
+var_y_wind: "y_wind"
+var_temp: "air_temperature"
+var_sphum: "specific_humidity"
 
 radiation_variables:
   - "DSWRFtoa"
@@ -51,4 +68,10 @@ high_res_data_variables:
   - "ULWRFtoa_coarse"
   - "ULWRFsfc_coarse"
   - "SHTFLsfc"
-  - "LHTFLsfc"
\ No newline at end of file
+  - "LHTFLsfc"
+
+var_source_name_map:
+  var_x_wind: "dQU"
+  var_y_wind: "dQV"
+  var_temp: "dQ1"
+  var_sphum: "dQ2"
\ No newline at end of file