Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serialbox to Netcdf tool] Collapse all rank into 1 if rank have different sized data #82

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ndsl/stencils/testing/savepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ class SavepointCase:
def __str__(self):
return f"{self.savepoint_name}-rank={self.rank}-call={self.i_call}"

@property
def exists(self) -> bool:
    """Report whether the input NetCDF file holds data for this case's rank.

    Opens ``<data_dir>/<savepoint_name>-In.nc`` and compares the size of its
    ``rank`` dimension against ``self.rank``.

    Returns:
        True when the size of the ``rank`` dimension is strictly greater
        than ``self.rank``, False otherwise.
    """
    input_path = os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc")
    # Open the dataset as a context manager so the underlying file handle is
    # closed as soon as the size has been read, instead of staying open until
    # the object happens to be garbage collected.
    with xr.open_dataset(input_path) as dataset:
        return dataset.sizes["rank"] > self.rank

@property
def ds_in(self) -> xr.Dataset:
return (
Expand Down
79 changes: 61 additions & 18 deletions ndsl/stencils/testing/serialbox_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_serializer(data_path: str, rank: int, data_name: Optional[str] = None):
name = data_name
else:
name = f"Generator_rank{rank}"
return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name)
return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name) # type: ignore


def main(
Expand All @@ -81,22 +81,13 @@ def main(
if namelist_filename_out != namelist_filename_in:
shutil.copyfile(os.path.join(data_path, "input.nml"), namelist_filename_out)
namelist = f90nml.read(namelist_filename_out)
if namelist["fv_core_nml"]["grid_type"] <= 3:
total_ranks = (
6
* namelist["fv_core_nml"]["layout"][0]
* namelist["fv_core_nml"]["layout"][1]
)
fv_core_nml: Dict[str, Any] = namelist["fv_core_nml"] # type: ignore
if fv_core_nml["grid_type"] <= 3:
total_ranks = 6 * fv_core_nml["layout"][0] * fv_core_nml["layout"][1]
else:
total_ranks = (
namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1]
)
nx = int(
(namelist["fv_core_nml"]["npx"] - 1) / (namelist["fv_core_nml"]["layout"][0])
)
ny = int(
(namelist["fv_core_nml"]["npy"] - 1) / (namelist["fv_core_nml"]["layout"][1])
)
total_ranks = fv_core_nml["layout"][0] * fv_core_nml["layout"][1]
nx = int((fv_core_nml["npx"] - 1) / (fv_core_nml["layout"][0]))
ny = int((fv_core_nml["npy"] - 1) / (fv_core_nml["layout"][1]))

# all ranks have the same names, just look at first one
serializer_0 = get_serializer(data_path, rank=0, data_name=data_name)
Expand All @@ -109,6 +100,7 @@ def main(
serializer_0.get_savepoint(savepoint_name)[0]
)
)
print(f"Exporting {savepoint_name}")
serializer_list = []
for rank in range(total_ranks):
serializer = get_serializer(data_path, rank, data_name)
Expand Down Expand Up @@ -149,7 +141,27 @@ def main(
if n_savepoints > 0:
encoding = {}
for varname in set(names_list).difference(["rank"]):
# Check that all ranks have the same size. If not, aggregate and
# feedback on one rank
colapse_all_ranks = False
data_shape = list(rank_list[0][varname][0].shape)
print(f" Exporting {varname} - {data_shape}")
for rank in range(total_ranks):
this_shape = list(rank_list[rank][varname][0].shape)
if data_shape != this_shape:
if len(data_shape) != 1:
raise ValueError(
"Arrays have different dimensions. "
f"E.g. rank 0 is {data_shape} "
f"and rank {rank} is {this_shape} "
)
else:
print(
f"... different shape for {varname} across ranks, collapsing in on rank."
)
colapse_all_ranks = True
break

if savepoint_name in [
"FVDynamics-In",
"FVDynamics-Out",
Expand All @@ -173,22 +185,53 @@ def main(
data_vars[varname] = get_data(
data_shape, total_ranks, n_savepoints, rank_list, varname
)
elif colapse_all_ranks:
data_vars[varname] = get_data_collapse_all_ranks(
total_ranks, n_savepoints, rank_list, varname
)
else:
data_vars[varname] = get_data(
data_shape, total_ranks, n_savepoints, rank_list, varname
)
if len(data_shape) > 2:
encoding[varname] = {"zlib": True, "complevel": 1}

dataset = xr.Dataset(data_vars=data_vars)
dataset.to_netcdf(
os.path.join(output_path, f"{savepoint_name}.nc"), encoding=encoding
)


def get_data_collapse_all_ranks(total_ranks, n_savepoints, output_list, varname):
if total_ranks <= 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<= 0? Can we have negative total_ranks??

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not, but I did a bit of "defensive coding" here. If we pass -1 for any reason, we don't crash.

return xr.DataArray([], dims=[])
# Build array shape - we hypothesis there's only 1 axis
K_shape = 0
for rank in range(total_ranks):
assert len(output_list[rank][varname][0].shape) == 1
K_shape = K_shape + output_list[rank][varname][0].shape[0]

array = np.full(
[n_savepoints, 1] + [K_shape],
fill_value=np.nan,
dtype=output_list[0][varname][0].dtype,
)
data = xr.DataArray(array, dims=["savepoint", "rank", f"dim_{varname}"])
last_size = 0
for rank in range(total_ranks):
for i_savepoint in range(n_savepoints):
rank_data = output_list[rank][varname][i_savepoint]
rank_data_size = rank_data.shape[0]
data[i_savepoint, 0, last_size : last_size + rank_data_size] = rank_data[:]
last_size += rank_data_size

return data


def get_data(data_shape, total_ranks, n_savepoints, output_list, varname):
# Read in dtype
if total_ranks <= 0:
return
return xr.DataArray([], dims=[])
# Read in dtype
varname_dtype = output_list[0][varname][0].dtype
# Build data array
array = np.full(
Expand Down
4 changes: 4 additions & 0 deletions ndsl/stencils/testing/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def test_sequential_savepoint(
)
if case.testobj.skip_test:
return
if not case.exists:
pytest.skip(f"Data at rank {case.rank} does not exists")
input_data = dataset_to_dict(case.ds_in)
input_names = (
case.testobj.serialnames(case.testobj.in_vars["data_vars"])
Expand Down Expand Up @@ -334,6 +336,8 @@ def test_parallel_savepoint(
return
if (grid == "compute") and not case.testobj.compute_grid_option:
pytest.xfail(f"Grid compute option not used for test {case.savepoint_name}")
if not case.exists:
pytest.skip(f"Data at rank {case.rank} does not exists")
input_data = dataset_to_dict(case.ds_in)
# run python version of functionality
output = case.testobj.compute_parallel(input_data, communicator)
Expand Down
Loading