Merge pull request #24 from ecmwf-projects/cuon-speed-up
Cuon speed up
garciampred authored Jul 24, 2024
2 parents eb85e08 + 22a746c commit caa3bf6
Showing 5 changed files with 122 additions and 34 deletions.
67 changes: 38 additions & 29 deletions cdsobs/ingestion/readers/cuon.py
@@ -4,7 +4,7 @@
import statistics
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
from typing import Iterable, List

import cftime
import dask
@@ -213,6 +213,23 @@ def read_table_data(
def filter_batch_stations(
    files: Iterable[Path], time_space_batch: TimeSpaceBatch
) -> list[Path]:
    station_metadata = get_cuon_stations()
    selected_end, selected_start = _get_times_in_seconds_from(
        time_space_batch.time_batch
    )
    lon_start, lon_end, lat_start, lat_end = time_space_batch.get_spatial_coverage()
    lon_mask = between(station_metadata.lon, lon_start, lon_end)
    lat_mask = between(station_metadata.lat, lat_start, lat_end)
    time_mask = numpy.logical_and(
        station_metadata["start of records"] <= selected_end,
        station_metadata["end of records"] >= selected_start,
    )
    mask = lon_mask * lat_mask * time_mask
    batch_stations = station_metadata.loc[mask].index
    return [f for f in files if f.name.split("_")[0] in batch_stations]


def get_cuon_stations():
    # Read file with CUON stations locations
    columns = [
        "start of records",
@@ -228,19 +245,7 @@ def filter_batch_stations(
    )
    station_metadata = pandas.read_json(cuon_stations_file, orient="index")
    station_metadata.columns = columns
    selected_end, selected_start = _get_times_in_seconds_from(
        time_space_batch.time_batch
    )
    lon_start, lon_end, lat_start, lat_end = time_space_batch.get_spatial_coverage()
    lon_mask = between(station_metadata.lon, lon_start, lon_end)
    lat_mask = between(station_metadata.lat, lat_start, lat_end)
    time_mask = numpy.logical_and(
        station_metadata["start of records"] <= selected_end,
        station_metadata["end of records"] >= selected_start,
    )
    mask = lon_mask * lat_mask * time_mask
    batch_stations = station_metadata.loc[mask].index
    return [f for f in files if f.name.split("_")[0] in batch_stations]
    return station_metadata


def read_cuon_netcdfs(
@@ -262,12 +267,7 @@ def read_cuon_netcdfs(
    cdm_tables = read_cdm_tables(config.cdm_tables_location, tables_to_use)
    files_and_slices = read_all_nc_slices(files, time_space_batch.time_batch)
    denormalized_tables_futures = []
    if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
        # This is for the tests.
        scheduler = "synchronous"
    else:
        # Do not use threads as HDF5 is not yet thread safe.
        scheduler = "processes"
    scheduler = get_scheduler()
    # Check for emptiness
    if len(files_and_slices) == 0:
        raise EmptyBatchException
@@ -289,6 +289,16 @@ def read_cuon_netcdfs(
    return pandas.concat(denormalized_tables)


def get_scheduler():
    if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
        # This is for the tests.
        scheduler = "synchronous"
    else:
        # Do not use threads as HDF5 is not yet thread safe.
        scheduler = "processes"
    return scheduler


def _get_denormalized_table_file(*args):
    try:
        return get_denormalized_table_file(*args)
@@ -329,7 +339,7 @@ def get_denormalized_table_file(
    spatial_mask = lon_mask * lat_mask
    if spatial_mask.sum() < len(spatial_mask):
        logger.info(
            f"Stations have been found outside the SpatialBatch ranges for {file_and_slices.path}, "
            f"Records have been found outside the SpatialBatch ranges for {file_and_slices.path}, "
            "filtering out."
        )
        dataset_cdm["header_table"] = dataset_cdm["header_table"].loc[spatial_mask]
@@ -487,17 +497,16 @@ def read_nc_file_slices(
    return result


def read_all_nc_slices(
    files: Iterable, time_batch: TimeBatch
) -> list[CUONFileandSlices]:
def read_all_nc_slices(files: List, time_batch: TimeBatch) -> list[CUONFileandSlices]:
    """Read variable slices of all station files using h5py."""
    tocs = []

    for file in files:
        logger.info(f"Reading slices from {file=}")
        toc = read_nc_file_slices(Path(file), time_batch)
        if toc is not None:
            tocs.append(toc)
        else:
            logger.warning("")
        toc = dask.delayed(read_nc_file_slices)(Path(file), time_batch)
        tocs.append(toc)

    scheduler = get_scheduler()
    tocs = dask.compute(*tocs, scheduler=scheduler, num_workers=min(len(files), 32))
    tocs = [t for t in tocs if t is not None]
    return tocs
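The speed-up in read_all_nc_slices comes from wrapping read_nc_file_slices in dask.delayed and evaluating all files in a single dask.compute call, with the scheduler chosen by get_scheduler. A minimal sketch of the same delayed/compute pattern, assuming an I/O-bound per-file reader; the CADSOBS_AVOID_MULTIPROCESS switch and the 32-worker cap mirror the code above, while read_one_file and the file names are illustrative placeholders:

import os

import dask


def get_scheduler():
    # Synchronous scheduler for the tests; processes otherwise, as HDF5 is not thread safe.
    if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
        return "synchronous"
    return "processes"


def read_one_file(path):
    # Stand-in for a per-file reader such as read_nc_file_slices; may return None on failure.
    return {"path": path}


files = [f"station_{i}.nc" for i in range(5)]
# Build the delayed tasks first, then run them all in one parallel compute call.
tasks = [dask.delayed(read_one_file)(f) for f in files]
results = dask.compute(*tasks, scheduler=get_scheduler(), num_workers=min(len(files), 32))
results = [r for r in results if r is not None]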
2 changes: 1 addition & 1 deletion cdsobs/sanity_checks.py
@@ -107,7 +107,7 @@ def check_retrieved_dataset(
    if not times_are_in_bounds:
        logger.warning("report_timestamp file has dates outside the expected interval")
    if times_index.isnull().any():
        logger.warning("Null values foung in report_timestamp")
        logger.warning("Null values found in report_timestamp")
    # Check observed_variables
    observed_variables = output_dataset.observed_variable
    if observed_variables.dtype.kind != "S":
4 changes: 2 additions & 2 deletions cdsobs/utils/types.py
@@ -4,8 +4,8 @@
from pydantic import AfterValidator
from pydantic_core import PydanticCustomError

LonTileSize = Literal[360, 180, 90, 45, 30, 20, 15, 10, 5]
LatTileSize = Literal[180, 90, 45, 30, 20, 15, 10, 5]
LonTileSize = Literal[360, 180, 90, 45, 30, 20, 15, 10, 5, 3, 2, 1]
LatTileSize = Literal[180, 90, 45, 30, 20, 15, 10, 5, 3, 2, 1]
TimeTileSize = Literal["month", "year"]
ByteSize = Annotated[int, pydantic.Field(gt=0)]
StrNotBlank = Annotated[str, pydantic.Field(min_length=1)]
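These Literal aliases are presumably consumed by pydantic configuration models, so the new 3, 2 and 1 degree tile sizes now validate while anything outside the list is rejected. A small sketch of that behaviour (the TileConfig model and its field names are illustrative, not part of the repository):

from typing import Literal

import pydantic

LonTileSize = Literal[360, 180, 90, 45, 30, 20, 15, 10, 5, 3, 2, 1]
LatTileSize = Literal[180, 90, 45, 30, 20, 15, 10, 5, 3, 2, 1]


class TileConfig(pydantic.BaseModel):
    # Hypothetical model: only the listed tile sizes pass validation.
    lon_tile_size: LonTileSize = 20
    lat_tile_size: LatTileSize = 20


TileConfig(lon_tile_size=3, lat_tile_size=1)  # accepted with the extended sizes
# TileConfig(lon_tile_size=7)                 # would raise pydantic.ValidationError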
4 changes: 2 additions & 2 deletions cdsobs/utils/utils.py
@@ -15,8 +15,8 @@
from cdsobs.utils.types import ByteSize


def compute_hash(ipath: Path, hash_function=hashlib.sha256, block_size=1048576):
    """Compute a hash in a memory efficient way using 1Mb blocks."""
def compute_hash(ipath: Path, hash_function=hashlib.sha256, block_size=10048576):
    """Compute a hash in a memory efficient way using 10Mb blocks."""
    with ipath.open("rb") as f:
        file_hash = hash_function()
        while True:
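The hunk above cuts off inside the read loop. A sketch of the full chunked-hashing pattern under the same signature; the loop body and the hexdigest return value are assumptions, not shown in the diff:

import hashlib
from pathlib import Path


def compute_hash(ipath: Path, hash_function=hashlib.sha256, block_size=10048576):
    # Read the file in fixed-size blocks so large files never sit fully in memory.
    with ipath.open("rb") as f:
        file_hash = hash_function()
        while True:
            block = f.read(block_size)
            if not block:
                # End of file: every block has been fed to the hash.
                break
            file_hash.update(block)
    return file_hash.hexdigest()  # assumption: the real helper may return the hash object instead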
79 changes: 79 additions & 0 deletions tests/scripts/cuon_analysis.py
@@ -0,0 +1,79 @@
from pathlib import Path

import cftime
import h5netcdf
import numpy
import pandas
from matplotlib import pyplot

from cdsobs.constants import TIME_UNITS
from cdsobs.ingestion.readers.cuon import _maybe_concat_chars, get_cuon_stations
from cdsobs.utils.logutils import get_logger

logger = get_logger(__name__)


def plot_station_number():
    stations = get_cuon_stations()
    stations["start of records"] = cftime.num2date(
        stations["start of records"], units=TIME_UNITS
    )
    stations["end of records"] = cftime.num2date(
        stations["end of records"], units=TIME_UNITS
    )
    print(stations.head().to_string())
    months = pandas.date_range("1905-01-01", "2023-12-31", freq="MS")
    station_number = pandas.Series(index=months)
    for monthdate in months:
        station_record_has_started = stations["start of records"] <= monthdate
        station_record_has_not_ended = stations["end of records"] >= monthdate
        station_number.loc[monthdate] = numpy.logical_and(
            station_record_has_started, station_record_has_not_ended
        ).sum()

    station_number.plot()
    pyplot.show()


def get_char_var_data(inc_group, variable):
    return _maybe_concat_chars(inc_group[variable][:])


def check_primary_keys_consistency():
    idir = Path("/data/public/converted_v19")
    for ipath in idir.glob("*.nc"):
        logger.info(f"Checking {ipath}")
        try:
            check_primary_keys_consistency_file(ipath)
        except Exception as e:
            logger.info(f"Exception captured for {ipath}: {e}")


def check_primary_keys_consistency_file(ipath):
    with h5netcdf.File(ipath) as inc:
        station_table = inc.groups["station_configuration"]
        header_table = inc.groups["header_table"]
        observations_table = inc.groups["observations_table"]
        station_ids_station_table = get_char_var_data(station_table, "primary_id")
        station_ids_header_table = get_char_var_data(header_table, "primary_station_id")
        ids_ok = set(station_ids_station_table) == set(station_ids_header_table)
        if not ids_ok:
            logger.warning(f"Station ids wrong for {ipath}")
        records_ok = set(station_table["record_number"][:]) == set(
            header_table["station_record_number"][:]
        )
        if not records_ok:
            logger.warning(f"Station record number wrong for {ipath}")
        report_ids_observations_table = get_char_var_data(
            observations_table, "report_id"
        )
        report_ids_header_table = get_char_var_data(header_table, "report_id")
        report_ids_ok = set(report_ids_observations_table) == set(
            report_ids_header_table
        )
        if not report_ids_ok:
            logger.warning(f"Station record ids wrong for {ipath}")


if __name__ == "__main__":
    check_primary_keys_consistency()
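The script's __main__ block only runs the primary-key check; plot_station_number has to be invoked separately. One possible way to run both, assuming the script's directory is importable (for example by running from the repository root with tests/scripts on sys.path) and that the hard-coded /data/public/converted_v19 directory exists:

from cuon_analysis import check_primary_keys_consistency, plot_station_number

plot_station_number()             # monthly count of active CUON stations, 1905-2023
check_primary_keys_consistency()  # cross-checks ids between the CDM tables of each file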
