benchmark_mass_cons_table.py now loops over individual files
gcpy/benchmark/modules/benchmark_mass_cons_table.py
- Refactored the code so that we now read one file at a time, in order
  to avoid memory issues when reading large files (e.g. at c180
  resolution). See the sketch below for an illustration of this pattern.
- Deleted objects at the end of the loop over times in order to force
  garbage collection.
- Removed references to dask_config, which is no longer needed.

Signed-off-by: Bob Yantosca <[email protected]>
yantosca committed Apr 23, 2024
1 parent df1e984 commit af8ef64
Showing 1 changed file with 40 additions and 28 deletions.
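
For context, here is a minimal sketch of the read-one-file-per-iteration pattern that the commit message describes. This is not code from the commit itself: the restart file names and the SpeciesRst_PassiveTracer variable are illustrative assumptions, and the mass calculation is reduced to a simple sum.

import numpy as np
import xarray as xr

# Hypothetical list of monthly restart files (names are illustrative)
files = [f"GEOSChem.Restart.2019{mm:02d}01_0000z.nc4" for mm in range(1, 13)]
total_masses = np.zeros(len(files))

# Keep xarray attributes through arithmetic operations
with xr.set_options(keep_attrs=True):
    for t_idx, path in enumerate(files):

        # Read one file at a time; .load() pulls the data fully into
        # memory, so only one file's worth of data is held at once
        dset = xr.open_dataset(path).load()

        # Stand-in for the real area-weighted total-mass computation
        total_masses[t_idx] = float(dset["SpeciesRst_PassiveTracer"].sum())

        # Delete the reference at the end of the loop body so that the
        # large dataset can be garbage-collected before the next read
        del dset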
gcpy/benchmark/modules/benchmark_mass_cons_table.py: 68 changes (40 additions & 28 deletions)
@@ -5,7 +5,6 @@
 import os
 import warnings
 import numpy as np
-from dask import config as dask_config
 import xarray as xr
 from gcpy.constants import skip_these_vars
 from gcpy.units import convert_units
@@ -122,7 +121,6 @@ def get_passive_tracer_varname(
 
 
 def compute_total_mass(
-        t_idx,
         dset,
         area,
         delta_p,
@@ -142,9 +140,7 @@
     total_mass : np.float64 : Total mass [Tg] of species.
     """
     # Keep xarray attributes and allow large chunks in Dask slicing
-    with xr.set_options(keep_attrs=True) and dask_config.set({
-            "array.slicing.split_large_chunks": False
-    }):
+    with xr.set_options(keep_attrs=True):
 
         # Local variables
         units = TARGET_UNITS
@@ -156,12 +152,12 @@
 
         # Compute mass in Tg
         darr = convert_units(
-            dset[varname].astype(np.float64).isel(time=t_idx),
+            dset[varname].astype(np.float64),
             varname,
             metadata,
             units,
             area_m2=area,
-            delta_p=delta_p.isel(time=t_idx),
+            delta_p=delta_p,
         )
 
         return np.sum(darr)
@@ -327,28 +323,20 @@ def make_benchmark_mass_conservation_table(
     # Replace whitespace with underscores in version labels
     ref_label = replace_whitespace(ref_label)
     dev_label = replace_whitespace(dev_label)
 
     # Preserve xarray attributes
-    with xr.set_options(keep_attrs=True) and dask_config.set({
-            "array.slicing.split_large_chunks": False
-    }):
+    with xr.set_options(keep_attrs=True):
 
         # ==============================================================
-        # Read data and make sure time dimensions are consistent
+        # Make sure Ref and Dev have consistent time dimensions
         # ==============================================================
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=xr.SerializationWarning)
 
-            # Pick the proper function to read the data
-            reader = dataset_reader(multi_files=True, verbose=False)
-
-            # Get data
-            ref_data = reader(ref_files, drop_variables=skip_these_vars).load()
-            dev_data = reader(dev_files, drop_variables=skip_these_vars).load()
-            ref_area = get_area(ref_areapath, ref_data)
-            dev_area = get_area(dev_areapath, dev_data)
-            ref_delta_prs = get_delta_pressure(ref_data)
-            dev_delta_prs = get_delta_pressure(dev_data)
+            # Make sure Ref & Dev have the same number of elements
+            if len(ref_files) != len(dev_files):
+                msg = "Ref and Dev have different time dimensions!"
+                raise ValueError(msg)
 
             # Get datetime values
             ref_time = get_datetimes_from_filenames(ref_files)
@@ -366,31 +354,55 @@
         # List for holding the datetimes
         display_dates = []
 
         # ==================================================================
         # Calculate global mass for the tracer at all restart dates
         # ==================================================================
+        # Pick the proper function to read the data
+        reader = dataset_reader(multi_files=False, verbose=False)
+
+        # ==============================================================
+        # Read data and make sure time dimensions are consistent
+        # Loop over files individually to avoid memory issues
+        # ==============================================================
         for t_idx, time in enumerate(dev_time):
 
+            # Get data
+            ref_data = reader(
+                ref_files[t_idx],
+                drop_variables=skip_these_vars
+            ).load()
+            dev_data = reader(
+                dev_files[t_idx],
+                drop_variables=skip_these_vars
+            ).load()
+            ref_area = get_area(ref_areapath, ref_data)
+            dev_area = get_area(dev_areapath, dev_data)
+            ref_delta_prs = get_delta_pressure(ref_data)
+            dev_delta_prs = get_delta_pressure(dev_data)
+
             # Save datetime string into display_dates list
             time = str(np.datetime_as_string(time, unit="m"))
             display_dates.append(time.replace("T", " "))
 
             # Compute total masses [Tg] for Ref & Dev
             ref_masses[t_idx] = compute_total_mass(
-                t_idx,
                 ref_data,
                 ref_area,
                 ref_delta_prs,
                 metadata,
             )
             dev_masses[t_idx] = compute_total_mass(
-                t_idx,
                 dev_data,
                 dev_area,
                 dev_delta_prs,
                 metadata,
             )
 
+            # Free memory in large objects
+            del ref_data
+            del dev_data
+            del ref_area
+            del dev_area
+            del ref_delta_prs
+            del dev_delta_prs
+
         # ==================================================================
         # Print masses and statistics to file
         # ==================================================================