From 1f94d0ec1297f08d0c37a152d7d0d9de6c3ce5d2 Mon Sep 17 00:00:00 2001 From: Jonathan Shimwell Date: Fri, 24 Nov 2023 14:16:43 +0000 Subject: [PATCH] squeezing end of tally shape only --- src/openmc_regular_mesh_plotter/core.py | 47 +++++++++++-------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/src/openmc_regular_mesh_plotter/core.py b/src/openmc_regular_mesh_plotter/core.py index 17c7b42..a19765b 100644 --- a/src/openmc_regular_mesh_plotter/core.py +++ b/src/openmc_regular_mesh_plotter/core.py @@ -21,6 +21,10 @@ _default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1} +def _squeeze_end_of_array(array, dims_required=3): + while len(array.shape) > dims_required: + array = np.squeeze(array, axis=len(array.shape)-1) + return array def plot_mesh_tally( tally: "openmc.Tally", @@ -112,19 +116,29 @@ def plot_mesh_tally( tally_slice = tally.get_slice(scores=[score]) - # if mesh.n_dimension == 3: + basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] + print('mesh.dimension', mesh.dimension) if 1 in mesh.dimension: index_of_2d = mesh.dimension.index(1) axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d] + if axis_of_2d in basis: # checks if the axis is being plotted, e.g is 'x' in 'xy' + raise ValueError( + "The selected tally has a mesh that has 1 dimension in the " + f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " + f"of {basis}." + ) - # todo check if 1 appears twice or three times, raise value error if so + # TODO check if 1 appears twice or three times, raise value error if so - tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze() + tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value)#.squeeze() - basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis] - if len(tally_data.shape) == 3: + tally_data = _squeeze_end_of_array(tally_data, dims_required=3) + + # if len(tally_data.shape) == 3: + if mesh.n_dimension == 3: if slice_index is None: + # finds the mid index slice_index = int(tally_data.shape[basis_to_index] / 2) if basis == "xz": @@ -137,31 +151,12 @@ def plot_mesh_tally( xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" else: # basis == 'xy' slice_data = tally_data[:, :, slice_index] + print('shape slice_data', slice_data.shape) data = np.rot90(slice_data, -3) xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" - # elif mesh.n_dimension == 2: - elif len(tally_data.shape) == 2: - if basis_to_index == index_of_2d: - slice_data = tally_data[:, :] - if basis == "xz": - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]" - elif basis == "yz": - data = np.flip(np.rot90(slice_data, -1)) - xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]" - else: # basis == 'xy' - data = np.rot90(slice_data, -3) - xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]" - - else: - raise ValueError( - "The selected tally has a mesh that has 1 dimension in the " - f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis " - f"of {basis}." - ) else: - raise ValueError("mesh n_dimension") + raise ValueError(f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported") if volume_normalization: # in a regular mesh all volumes are the same so we just divide by the first