Skip to content

Commit

Permalink
Define StandardPlot add_traces and parse_data (#545)
Browse files Browse the repository at this point in the history
* Define add_traces and parse_data

* Update pybop/plot/standard_plots.py

Co-authored-by: Brady Planden <[email protected]>

* Remove self.x and self.y

* Update CHANGELOG.md

* Fix if condition

---------

Co-authored-by: Brady Planden <[email protected]>
  • Loading branch information
NicolaCourtier and BradyPlanden committed Nov 25, 2024
1 parent 9cf3da4 commit e71c067
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- [#544](https://github.com/pybop-team/PyBOP/issues/544) - Allows iterative plotting using `StandardPlot`.
- [#541](https://github.com/pybop-team/PyBOP/pull/541) - Adds `ScaledLogLikelihood` and `BaseMetaLikelihood` classes.
- [#409](https://github.com/pybop-team/PyBOP/pull/409) - Adds plotting and convergence methods for Monte Carlo sampling. Includes open-access Tesla 4680 dataset for Bayesian inference example. Fixes transformations for sampling.
- [#531](https://github.com/pybop-team/PyBOP/pull/531) - Adds Voronoi optimiser surface plot (`pybop.plot.surface`) for fast optimiser aligned cost visualisation.
Expand Down
147 changes: 86 additions & 61 deletions pybop/plot/standard_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class StandardPlot:
Parameters
----------
x : list or np.ndarray
x : list or np.ndarray, optional
X-axis data points.
y : list or np.ndarray
y : list or np.ndarray, optional
Primary Y-axis data points for simulated model output.
layout : Plotly layout, optional
A layout for the figure, overrides the layout options (default: None).
Expand All @@ -66,16 +66,15 @@ class StandardPlot:

def __init__(
self,
x,
y,
x=None,
y=None,
layout=None,
layout_options=None,
trace_options=None,
trace_names=None,
trace_name_width=40,
):
self.x = x
self.y = y
self.traces = []
self.layout = layout
self.trace_name_width = trace_name_width

Expand All @@ -90,68 +89,16 @@ def __init__(
if trace_options:
self.trace_options.update(trace_options)

# Check trace_names and set attribute
if isinstance(trace_names, str):
self.trace_names = [trace_names]
else:
self.trace_names = trace_names

# Check type and dimensions of data
# What we want is a list of 'things plotly can take', e.g. numpy arrays or lists of numbers
if isinstance(self.x, list):
# If it's a list of numpy arrays, it's fine
# If it's a list of lists, it's fine
# If it's neither, it's a list of numbers that we need to wrap
if not isinstance(self.x[0], np.ndarray) and not isinstance(
self.x[0], list
):
self.x = [self.x]
elif isinstance(self.x, np.ndarray):
self.x = np.squeeze(self.x)
if self.x.ndim == 1:
self.x = [self.x]
else:
self.x = self.x.tolist()
if isinstance(self.y, list):
if not isinstance(self.y[0], np.ndarray) and not isinstance(
self.y[0], list
):
self.y = [self.y]
if isinstance(self.y, np.ndarray):
self.y = np.squeeze(self.y)
if self.y.ndim == 1:
self.y = [self.y]
else:
self.y = self.y.tolist()
if len(self.x) > 1 and len(self.x) != len(self.y):
raise ValueError(
"Input x should have either one data series or the same number as y."
)

# Attempt to import plotly when an instance is created
self.go = PlotlyManager().go

# Create layout
if self.layout is None:
self.layout = self.go.Layout(**self.layout_options)

# Wrap trace names
if self.trace_names is not None:
for i, name in enumerate(self.trace_names):
self.trace_names[i] = self.wrap_text(name, width=self.trace_name_width)

# Create a trace for each trajectory
self.traces = []
x = self.x[0]
for i in range(0, len(self.y)):
if len(self.x) > 1:
x = self.x[i]
if self.trace_names is not None:
self.trace_options["name"] = self.trace_names[i]
else:
self.trace_options["showlegend"] = False
trace = self.create_trace(x, self.y[i], **self.trace_options)
self.traces.append(trace)
# Add traces
if x is not None and y is not None:
self.add_traces(x, y, trace_names)

def __call__(self, show=True):
"""
Expand All @@ -168,6 +115,84 @@ def __call__(self, show=True):

return fig

def add_traces(self, x, y, trace_names=None, **trace_options):
"""
Add a set of traces to the plot dictionary.
Parameters
----------
x : list or np.ndarray
X-axis data points.
y : list or np.ndarray
Primary Y-axis data points for simulated model output.
trace_names : str or list[str], optional
Name(s) for the primary trace(s) (default: None).
"""
options = self.trace_options.copy()
options.update(trace_options)

# Check and wrap trace names
if trace_names is not None:
if isinstance(trace_names, str):
trace_names = [trace_names]
for i, name in enumerate(trace_names):
trace_names[i] = self.wrap_text(name, width=self.trace_name_width)

# Parse the data
x, y = self.parse_data(x, y)

# Create a trace for each trajectory
xi = x[0]
for i in range(0, len(y)):
trace_options = options.copy()
if len(x) > 1:
xi = x[i]
if trace_names is not None:
trace_options["name"] = trace_names[i]
else:
trace_options["showlegend"] = False
trace = self.create_trace(xi, y[i], **trace_options)
self.traces.append(trace)

def parse_data(self, x, y):
"""
Check the type and dimensions of the data and convert if necessary to a list
of 'things plotly can take', e.g. numpy arrays or lists of numbers.
Parameters
----------
x : list or np.ndarray, optional
X-axis data points.
y : list or np.ndarray, optional
Primary Y-axis data points for simulated model output.
"""
if isinstance(x, list):
# If it's a list of numpy arrays, it's fine
# If it's a list of lists, it's fine
# If it's neither, it's a list of numbers that we need to wrap
if not isinstance(x[0], np.ndarray) and not isinstance(x[0], list):
x = [x]
elif isinstance(x, np.ndarray):
x = np.squeeze(x)
if x.ndim == 1:
x = [x]
else:
x = x.tolist()
if isinstance(y, list):
if not isinstance(y[0], np.ndarray) and not isinstance(y[0], list):
y = [y]
if isinstance(y, np.ndarray):
y = np.squeeze(y)
if y.ndim == 1:
y = [y]
else:
y = y.tolist()
if len(x) > 1 and len(x) != len(y):
raise ValueError(
"Input x should have either one data series or the same number as y."
)
return x, y

def create_trace(self, x, y, **trace_options):
"""
Create a trace for the Plotly figure.
Expand Down

0 comments on commit e71c067

Please sign in to comment.