From e71c067a1d601164a19f372dfa937b7178736d75 Mon Sep 17 00:00:00 2001 From: NicolaCourtier <45851982+NicolaCourtier@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:35:24 +0000 Subject: [PATCH] Define StandardPlot add_traces and parse_data (#545) * Define add_traces and parse_data * Update pybop/plot/standard_plots.py Co-authored-by: Brady Planden <55357039+BradyPlanden@users.noreply.github.com> * Remove self.x and self.y * Update CHANGELOG.md * Fix if condition --------- Co-authored-by: Brady Planden <55357039+BradyPlanden@users.noreply.github.com> --- CHANGELOG.md | 1 + pybop/plot/standard_plots.py | 147 ++++++++++++++++++++--------------- 2 files changed, 87 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6acea1fc..7df415309 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- [#544](https://github.com/pybop-team/PyBOP/issues/544) - Allows iterative plotting using `StandardPlot`. - [#541](https://github.com/pybop-team/PyBOP/pull/541) - Adds `ScaledLogLikelihood` and `BaseMetaLikelihood` classes. - [#409](https://github.com/pybop-team/PyBOP/pull/409) - Adds plotting and convergence methods for Monte Carlo sampling. Includes open-access Tesla 4680 dataset for Bayesian inference example. Fixes transformations for sampling. - [#531](https://github.com/pybop-team/PyBOP/pull/531) - Adds Voronoi optimiser surface plot (`pybop.plot.surface`) for fast optimiser aligned cost visualisation. diff --git a/pybop/plot/standard_plots.py b/pybop/plot/standard_plots.py index 35409b640..f5e5e3088 100644 --- a/pybop/plot/standard_plots.py +++ b/pybop/plot/standard_plots.py @@ -43,9 +43,9 @@ class StandardPlot: Parameters ---------- - x : list or np.ndarray + x : list or np.ndarray, optional X-axis data points. - y : list or np.ndarray + y : list or np.ndarray, optional Primary Y-axis data points for simulated model output. layout : Plotly layout, optional A layout for the figure, overrides the layout options (default: None). @@ -66,16 +66,15 @@ class StandardPlot: def __init__( self, - x, - y, + x=None, + y=None, layout=None, layout_options=None, trace_options=None, trace_names=None, trace_name_width=40, ): - self.x = x - self.y = y + self.traces = [] self.layout = layout self.trace_name_width = trace_name_width @@ -90,44 +89,6 @@ def __init__( if trace_options: self.trace_options.update(trace_options) - # Check trace_names and set attribute - if isinstance(trace_names, str): - self.trace_names = [trace_names] - else: - self.trace_names = trace_names - - # Check type and dimensions of data - # What we want is a list of 'things plotly can take', e.g. numpy arrays or lists of numbers - if isinstance(self.x, list): - # If it's a list of numpy arrays, it's fine - # If it's a list of lists, it's fine - # If it's neither, it's a list of numbers that we need to wrap - if not isinstance(self.x[0], np.ndarray) and not isinstance( - self.x[0], list - ): - self.x = [self.x] - elif isinstance(self.x, np.ndarray): - self.x = np.squeeze(self.x) - if self.x.ndim == 1: - self.x = [self.x] - else: - self.x = self.x.tolist() - if isinstance(self.y, list): - if not isinstance(self.y[0], np.ndarray) and not isinstance( - self.y[0], list - ): - self.y = [self.y] - if isinstance(self.y, np.ndarray): - self.y = np.squeeze(self.y) - if self.y.ndim == 1: - self.y = [self.y] - else: - self.y = self.y.tolist() - if len(self.x) > 1 and len(self.x) != len(self.y): - raise ValueError( - "Input x should have either one data series or the same number as y." - ) - # Attempt to import plotly when an instance is created self.go = PlotlyManager().go @@ -135,23 +96,9 @@ def __init__( if self.layout is None: self.layout = self.go.Layout(**self.layout_options) - # Wrap trace names - if self.trace_names is not None: - for i, name in enumerate(self.trace_names): - self.trace_names[i] = self.wrap_text(name, width=self.trace_name_width) - - # Create a trace for each trajectory - self.traces = [] - x = self.x[0] - for i in range(0, len(self.y)): - if len(self.x) > 1: - x = self.x[i] - if self.trace_names is not None: - self.trace_options["name"] = self.trace_names[i] - else: - self.trace_options["showlegend"] = False - trace = self.create_trace(x, self.y[i], **self.trace_options) - self.traces.append(trace) + # Add traces + if x is not None and y is not None: + self.add_traces(x, y, trace_names) def __call__(self, show=True): """ @@ -168,6 +115,84 @@ def __call__(self, show=True): return fig + def add_traces(self, x, y, trace_names=None, **trace_options): + """ + Add a set of traces to the plot dictionary. + + Parameters + ---------- + x : list or np.ndarray + X-axis data points. + y : list or np.ndarray + Primary Y-axis data points for simulated model output. + trace_names : str or list[str], optional + Name(s) for the primary trace(s) (default: None). + """ + options = self.trace_options.copy() + options.update(trace_options) + + # Check and wrap trace names + if trace_names is not None: + if isinstance(trace_names, str): + trace_names = [trace_names] + for i, name in enumerate(trace_names): + trace_names[i] = self.wrap_text(name, width=self.trace_name_width) + + # Parse the data + x, y = self.parse_data(x, y) + + # Create a trace for each trajectory + xi = x[0] + for i in range(0, len(y)): + trace_options = options.copy() + if len(x) > 1: + xi = x[i] + if trace_names is not None: + trace_options["name"] = trace_names[i] + else: + trace_options["showlegend"] = False + trace = self.create_trace(xi, y[i], **trace_options) + self.traces.append(trace) + + def parse_data(self, x, y): + """ + Check the type and dimensions of the data and convert if necessary to a list + of 'things plotly can take', e.g. numpy arrays or lists of numbers. + + Parameters + ---------- + x : list or np.ndarray, optional + X-axis data points. + y : list or np.ndarray, optional + Primary Y-axis data points for simulated model output. + """ + if isinstance(x, list): + # If it's a list of numpy arrays, it's fine + # If it's a list of lists, it's fine + # If it's neither, it's a list of numbers that we need to wrap + if not isinstance(x[0], np.ndarray) and not isinstance(x[0], list): + x = [x] + elif isinstance(x, np.ndarray): + x = np.squeeze(x) + if x.ndim == 1: + x = [x] + else: + x = x.tolist() + if isinstance(y, list): + if not isinstance(y[0], np.ndarray) and not isinstance(y[0], list): + y = [y] + if isinstance(y, np.ndarray): + y = np.squeeze(y) + if y.ndim == 1: + y = [y] + else: + y = y.tolist() + if len(x) > 1 and len(x) != len(y): + raise ValueError( + "Input x should have either one data series or the same number as y." + ) + return x, y + def create_trace(self, x, y, **trace_options): """ Create a trace for the Plotly figure.