denormalize each file data in a multiprocess loop
garciampred committed Jan 30, 2024
1 parent 64bfeb9 commit 4f0dfab
Showing 1 changed file with 75 additions and 83 deletions.
158 changes: 75 additions & 83 deletions cdsobs/ingestion/readers/cuon.py
@@ -1,6 +1,5 @@
import calendar
import importlib
import itertools
import os
import statistics
from dataclasses import dataclass
@@ -201,43 +200,14 @@ def _maybe_swap_bytes(field_data):
return field_data


def read_all_nc_files(
files_and_slices: list[CUONFileandSlices], table_name: str, time_batch: TimeBatch
def read_table_data(
file_and_slices: CUONFileandSlices, table_name: str, time_batch: TimeBatch
) -> pandas.DataFrame:
"""Read nc table of all station files using h5py."""
results = []
if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
# This is for the tests.
scheduler = "synchronous"
else:
# Do not use threads as HDF5 is not yet thread safe.
scheduler = "processes"
# Use dask to speed up the process
for file_and_slices in files_and_slices:
logger.info(f"Reading {table_name=} from {file_and_slices.path}")
results.append(
dask.delayed(_read_nc_file)(file_and_slices, table_name, time_batch)
)
results = dask.compute(
*results, scheduler=scheduler, num_workers=min(len(files_and_slices), 32)
)
result = _read_nc_file(file_and_slices, table_name, time_batch)

results = [r for r in results if r is not None]
if len(results) >= 1:
fields = sorted(
set(itertools.chain.from_iterable([list(r.keys()) for r in results]))
)
final_data = {}
for field in fields:
to_concat = []
for r in results:
if field in r:
to_concat.append(r[field])
else:
file_data_len = len(r[list(r)[0]])
to_concat.append(numpy.repeat(numpy.nan, file_data_len))
final_data[field] = numpy.concatenate(to_concat)
final_df_out = pandas.DataFrame(final_data)
if result is not None:
final_df_out = pandas.DataFrame(result)
else:
final_df_out = pandas.DataFrame()
# Reduce field size for memory efficiency
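
For reference, a minimal sketch of what such a per-file table read can look like with h5py plus pandas. The group-per-table layout and the helper name read_table_sketch are assumptions for illustration; the real reader, _read_nc_file, also applies the time slices and the byte-order fix from _maybe_swap_bytes above.

import h5py
import pandas

def read_table_sketch(path: str, table_name: str) -> pandas.DataFrame:
    # Assumed layout: one HDF5 group per table, holding equal-length
    # 1-D datasets, one per column.
    with h5py.File(path, "r") as station_file:
        group = station_file[table_name]
        return pandas.DataFrame({name: group[name][:] for name in group})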
@@ -295,60 +265,82 @@ def read_cuon_netcdfs(
tables_to_use = config.get_dataset(dataset_name).available_cdm_tables
cdm_tables = read_cdm_tables(config.cdm_tables_location, tables_to_use)
files_and_slices = read_all_nc_slices(files, time_space_batch.time_batch)
denormalized_tables_futures = []
denormalized_tables = []
if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
# This is for the tests.
scheduler = "synchronous"
else:
# Do not use threads as HDF5 is not yet thread safe.
scheduler = "processes"
# Use dask to speed up the process
for file_and_slices in files_and_slices:
dataset_cdm: dict[str, pandas.DataFrame] = {}
for table_name, table_definition in cdm_tables.items():
# Fix era5fb having different names in the CDM and in the files
if table_name == "era5fb_table":
table_name_in_file = "era5fb"
else:
table_name_in_file = table_name
# Read table data
table_data = read_all_nc_files(
[file_and_slices], table_name_in_file, time_space_batch.time_batch
)
# Make sure that latitude and longitude always carry their table name.
table_data = _fix_table_data(
dataset_cdm, table_data, table_definition, table_name
)
dataset_cdm[table_name] = table_data
# Filter stations outside of the batch
lats = dataset_cdm["header_table"]["latitude|header_table"]
lons = dataset_cdm["header_table"]["longitude|header_table"]
lon_start, lon_end, lat_start, lat_end = time_space_batch.get_spatial_coverage()
lon_mask = between(lons, lon_start, lon_end)
lat_mask = between(lats, lat_start, lat_end)
spatial_mask = lon_mask * lat_mask
if spatial_mask.sum() < len(spatial_mask):
logger.info(
"Stations have been found outside the SpatialBatch ranges, "
"filtering out."
)
dataset_cdm["header_table"] = dataset_cdm["header_table"].loc[spatial_mask]
# Denormalize tables
denormalized_table_file = denormalize_tables(
cdm_tables, dataset_cdm, tables_to_use, ignore_errors=True
denormalized_table_future = dask.delayed(get_denormalized_table_file)(
cdm_tables, config, file_and_slices, tables_to_use, time_space_batch
)
# Decode time
if len(denormalized_table_file) > 0:
for time_field in ["record_timestamp", "report_timestamp"]:
denormalized_table_file.loc[:, time_field] = cftime.num2date(
denormalized_table_file.loc[:, time_field],
constants.TIME_UNITS,
only_use_cftime_datetimes=False,
)
else:
logger.warning(f"No data was found in file {file_and_slices.path}")
# Decode variable names
code_dict = get_var_code_dict(config.cdm_tables_location)
denormalized_table_file["observed_variable"] = denormalized_table_file[
"observed_variable"
].map(code_dict)
denormalized_tables.append(denormalized_table_file)
denormalized_tables_futures.append(denormalized_table_future)
denormalized_tables = dask.compute(
*denormalized_tables_futures,
scheduler=scheduler,
num_workers=min(len(files_and_slices), 32),
)
return pandas.concat(denormalized_tables)
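
The parallel pattern above is plain dask: build one delayed call per file, then compute them all at once on the chosen scheduler. A self-contained sketch of the same idea, with process_one_file as a hypothetical stand-in for get_denormalized_table_file:

import os
import dask

def process_one_file(path: str) -> str:
    # Hypothetical stand-in for the per-file work.
    return path.upper()

# HDF5 is not thread safe, so default to processes; tests can force a
# synchronous run through the environment variable.
if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
    scheduler = "synchronous"
else:
    scheduler = "processes"

paths = ["station_a.nc", "station_b.nc"]
futures = [dask.delayed(process_one_file)(path) for path in paths]
results = dask.compute(*futures, scheduler=scheduler, num_workers=min(len(paths), 32))

Keeping the per-file body in a named, top-level function (here get_denormalized_table_file) also matters for the processes scheduler, since the callable must be picklable.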


def get_denormalized_table_file(
cdm_tables, config, file_and_slices, tables_to_use, time_space_batch
):
dataset_cdm: dict[str, pandas.DataFrame] = {}
for table_name, table_definition in cdm_tables.items():
# Fix era5fb having different names in the CDM and in the files
if table_name == "era5fb_table":
table_name_in_file = "era5fb"
else:
table_name_in_file = table_name
# Read table data
table_data = read_table_data(
file_and_slices, table_name_in_file, time_space_batch.time_batch
)
# Make sure that latitude and longitude always carry their table name.
table_data = _fix_table_data(
dataset_cdm, table_data, table_definition, table_name
)
dataset_cdm[table_name] = table_data
# Filter stations outside of the batch
lats = dataset_cdm["header_table"]["latitude|header_table"]
lons = dataset_cdm["header_table"]["longitude|header_table"]
lon_start, lon_end, lat_start, lat_end = time_space_batch.get_spatial_coverage()
lon_mask = between(lons, lon_start, lon_end)
lat_mask = between(lats, lat_start, lat_end)
spatial_mask = lon_mask * lat_mask
if spatial_mask.sum() < len(spatial_mask):
logger.info(
"Stations have been found outside the SpatialBatch ranges, "
"filtering out."
)
dataset_cdm["header_table"] = dataset_cdm["header_table"].loc[spatial_mask]
# Denormalize tables
denormalized_table_file = denormalize_tables(
cdm_tables, dataset_cdm, tables_to_use, ignore_errors=True
)
# Decode time
if len(denormalized_table_file) > 0:
for time_field in ["record_timestamp", "report_timestamp"]:
denormalized_table_file.loc[:, time_field] = cftime.num2date(
denormalized_table_file.loc[:, time_field],
constants.TIME_UNITS,
only_use_cftime_datetimes=False,
)
else:
logger.warning(f"No data was found in file {file_and_slices.path}")
# Decode variable names
code_dict = get_var_code_dict(config.cdm_tables_location)
denormalized_table_file["observed_variable"] = denormalized_table_file[
"observed_variable"
].map(code_dict)
return denormalized_table_file
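
The time decoding above uses cftime.num2date; with only_use_cftime_datetimes=False and a standard calendar it yields ordinary Python datetimes. A small illustration, where the units string is a hypothetical stand-in for constants.TIME_UNITS:

import cftime

units = "seconds since 1900-01-01 00:00:00"  # assumed units, stands in for constants.TIME_UNITS
decoded = cftime.num2date([0, 86400], units, only_use_cftime_datetimes=False)
# decoded -> datetimes for 1900-01-01 00:00 and 1900-01-02 00:00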


def _fix_table_data(
dataset_cdm: dict[str, pandas.DataFrame],
table_data: pandas.DataFrame,
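_fix_table_data is cut off here, but the column convention it maintains is visible earlier: shared coordinates keep their source table as a |table_name suffix (e.g. latitude|header_table). A hedged sketch of that renaming step only, not the actual implementation:

import pandas

def suffix_coordinates(table_data: pandas.DataFrame, table_name: str) -> pandas.DataFrame:
    # Rename e.g. "latitude" to "latitude|header_table" so the source
    # table stays identifiable after tables are merged.
    renames = {
        column: f"{column}|{table_name}"
        for column in ("latitude", "longitude")
        if column in table_data.columns
    }
    return table_data.rename(columns=renames)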
