Skip to content

Commit

Permalink
#1413 n_rows option
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 5, 2021
1 parent df4c398 commit 4eae4b4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 34 deletions.
12 changes: 4 additions & 8 deletions examples/scripts/SPMe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@

# plot
plot = pybamm.QuickPlot(
[solution] * 2,
solution,
[
# "Negative particle concentration [mol.m-3]",
"Negative particle concentration [mol.m-3]",
"Electrolyte concentration [mol.m-3]",
# "Positive particle concentration [mol.m-3]",
"Positive particle concentration [mol.m-3]",
"Current [A]",
"Negative electrode potential [V]",
"Electrolyte potential [V]",
Expand All @@ -45,9 +45,5 @@
],
time_unit="seconds",
spatial_unit="um",
variable_limits="tight",
)
plot.plot(0) # dynamic_plot()
import matplotlib.pyplot as plt

plt.show()
plot.dynamic_plot()
62 changes: 36 additions & 26 deletions pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class QuickPlot(object):
The linestyles to loop over when plotting. Defaults to ["-", ":", "--", "-."]
figsize : tuple of floats, optional
The size of the figure to make
n_rows : int, optional
The number of rows to use. If None (default), floor(n // sqrt(n)) is used where
n = len(output_variables) so that the plot is as square as possible
time_unit : str, optional
Format for the time output ("hours", "minutes", or "seconds")
spatial_unit : str, optional
Expand All @@ -98,6 +101,7 @@ def __init__(
colors=None,
linestyles=None,
figsize=None,
n_rows=None,
time_unit=None,
spatial_unit="um",
variable_limits="fixed",
Expand Down Expand Up @@ -137,7 +141,38 @@ def __init__(
else:
self.colors = LoopList(colors)
self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."])
self.figsize = figsize or (15, 8)

# Default output variables for lead-acid and lithium-ion
if output_variables is None:
if isinstance(models[0], pybamm.lithium_ion.BaseModel):
output_variables = [
"Negative particle surface concentration [mol.m-3]",
"Electrolyte concentration [mol.m-3]",
"Positive particle surface concentration [mol.m-3]",
"Current [A]",
"Negative electrode potential [V]",
"Electrolyte potential [V]",
"Positive electrode potential [V]",
"Terminal voltage [V]",
]
elif isinstance(models[0], pybamm.lead_acid.BaseModel):
output_variables = [
"Interfacial current density [A.m-2]",
"Electrolyte concentration [mol.m-3]",
"Current [A]",
"Porosity",
"Electrolyte potential [V]",
"Terminal voltage [V]",
]

self.n_rows = n_rows or int(
len(output_variables) // np.sqrt(len(output_variables))
)
self.n_cols = int(np.ceil(len(output_variables) / self.n_rows))

figwidth_default = min(15, 4 * self.n_cols)
figheight_default = min(8, 1 + 3 * self.n_rows)
self.figsize = figsize or (figwidth_default, figheight_default)

# Spatial scales (default to 1 if information not in model)
if spatial_unit == "m":
Expand Down Expand Up @@ -183,29 +218,6 @@ def __init__(
self.min_t = min_t / time_scaling_factor
self.max_t = max_t / time_scaling_factor

# Default output variables for lead-acid and lithium-ion
if output_variables is None:
if isinstance(models[0], pybamm.lithium_ion.BaseModel):
output_variables = [
"Negative particle surface concentration [mol.m-3]",
"Electrolyte concentration [mol.m-3]",
"Positive particle surface concentration [mol.m-3]",
"Current [A]",
"Negative electrode potential [V]",
"Electrolyte potential [V]",
"Positive electrode potential [V]",
"Terminal voltage [V]",
]
elif isinstance(models[0], pybamm.lead_acid.BaseModel):
output_variables = [
"Interfacial current density [A.m-2]",
"Electrolyte concentration [mol.m-3]",
"Current [A]",
"Porosity",
"Electrolyte potential [V]",
"Terminal voltage [V]",
]

# Prepare dictionary of variables
# output_variables is a list of strings or lists, e.g.
# ["var 1", ["variable 2", "var 3"]]
Expand Down Expand Up @@ -254,8 +266,6 @@ def set_output_variables(self, output_variables, solutions):

# Calculate subplot positions based on number of variables supplied
self.subplot_positions = {}
self.n_rows = int(len(output_variables) // np.sqrt(len(output_variables)))
self.n_cols = int(np.ceil(len(output_variables) / self.n_rows))

for k, variable_tuple in enumerate(output_variables):
# Prepare list of variables
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,14 @@ def test_simple_ode_model(self):
linestyles=["-", "--"],
figsize=(1, 2),
labels=["sol 1", "sol 2"],
n_rows=2,
)
self.assertEqual(quick_plot.colors, ["r", "g", "b"])
self.assertEqual(quick_plot.linestyles, ["-", "--"])
self.assertEqual(quick_plot.figsize, (1, 2))
self.assertEqual(quick_plot.labels, ["sol 1", "sol 2"])
self.assertEqual(quick_plot.n_rows, 2)
self.assertEqual(quick_plot.n_cols, 1)

# Test different time units
quick_plot = pybamm.QuickPlot(solution, ["a"])
Expand Down

0 comments on commit 4eae4b4

Please sign in to comment.