From 9842de192d0f03195361e1a4de466810ecf16297 Mon Sep 17 00:00:00 2001 From: Naoki Kanazawa Date: Tue, 19 Sep 2023 23:55:48 +0900 Subject: [PATCH] Update mpl drawer backend to support unshared axes (#1242) ### Summary Experiment plotter can create a multi-sub canvas (axes) figure, but x and y plot range (xlim/ylim in matplotlib) must be shared among sub axes. This PR relaxes this limitation and allows a developer to write analysis class that can generate figure containing axes with different plot ranges. ### Details and comments Several figure options are updated to take a list of options for each axis. New options ``sharex`` and ``sharey`` are added to unlink the axis settings. These new options are True by default for backward compatibility. --- .../visualization/drawers/base_drawer.py | 43 ++-- .../visualization/drawers/mpl_drawer.py | 211 ++++++++++-------- .../visualization/plotters/base_plotter.py | 71 +++--- ...n-with-unshared-axis-9f7bfe272353086b.yaml | 10 + 4 files changed, 192 insertions(+), 143 deletions(-) create mode 100644 releasenotes/notes/add-support-for-visualization-with-unshared-axis-9f7bfe272353086b.yaml diff --git a/qiskit_experiments/visualization/drawers/base_drawer.py b/qiskit_experiments/visualization/drawers/base_drawer.py index 69f7915e8b..ea63e0afd2 100644 --- a/qiskit_experiments/visualization/drawers/base_drawer.py +++ b/qiskit_experiments/visualization/drawers/base_drawer.py @@ -194,27 +194,38 @@ def _default_figure_options(cls) -> Options: there are multiple columns in the canvas, this could be a list of labels. ylabel (Union[str, List[str]]): Y-axis label string of the output figure. If there are multiple rows in the canvas, this could be a list of labels. - xlim (Tuple[float, float]): Min and max value of the horizontal axis. If not - provided, it is automatically scaled based on the input data points. - ylim (Tuple[float, float]): Min and max value of the vertical axis. If not - provided, it is automatically scaled based on the input data points. - xval_unit (str): Unit of x values. No scaling prefix is needed here as this - is controlled by ``xval_unit_scale``. - yval_unit (str): Unit of y values. See ``xval_unit`` for details. - xval_unit_scale (bool): Whether to add an SI unit prefix to ``xval_unit`` if - needed. For example, when the x values represent time and + xlim (Union[Tuple[float, float], List[Tuple[float, float]]): Min and max value + of the horizontal axis. If not provided, it is automatically scaled based + on the input data points. If there are multiple columns in the canvas, + this could be a list of xlims. + ylim (Union[Tuple[float, float], List[Tuple[float, float]]): Min and max value + of the vertical axis. If not provided, it is automatically scaled based + on the input data points. If there are multiple rows in the canvas, + this could be a list of ylims. + xval_unit (Union[str, List[str]]): Unit of x values. + No scaling prefix is needed here as this is controlled by ``xval_unit_scale``. + If there are multiple columns in the canvas, this could be a list of xval_units. + yval_unit (Union[str, List[str]]): Unit of y values. + No scaling prefix is needed here as this is controlled by ``yval_unit_scale``. + If there are multiple rows in the canvas, this could be a list of yval_units. + xval_unit_scale (Union[bool, List[bool]]): Whether to add an SI unit prefix to + ``xval_unit`` if needed. For example, when the x values represent time and ``xval_unit="s"``, ``xval_unit_scale=True`` adds an SI unit prefix to ``"s"`` based on X values of plotted data. In the output figure, the prefix is automatically selected based on the maximum value in this axis. If your x values are in [1e-3, 1e-4], they are displayed as [1 ms, 10 ms]. By default, this option is set to ``True``. If ``False`` is provided, the axis numbers will be displayed in the scientific notation. - yval_unit_scale (bool): Whether to add an SI unit prefix to ``yval_unit`` if - needed. See ``xval_unit_scale`` for details. + If there are multiple columns in the canvas, this could be a list of xval_unit_scale. + yval_unit_scale (Union[bool, List[bool]]): Whether to add an SI unit prefix to + ``yval_unit`` if needed. See ``xval_unit_scale`` for details. + If there are multiple rows in the canvas, this could be a list of yval_unit_scale. xscale (str): The scaling of the x-axis, such as ``log`` or ``linear``. - yscale (str): See ``xscale`` for details. + yscale (str): The scaling of the y-axis, such as ``log`` or ``linear``. figure_title (str): Title of the figure. Defaults to None, i.e. nothing is shown. + sharex (bool): Set True to share x-axis ticks among sub-plots. + sharey (bool): Set True to share y-axis ticks among sub-plots. series_params (Dict[str, Dict[str, Any]]): A dictionary of parameters for each series. This is keyed on the name for each series. Sub-dictionary is expected to have the following three configurations, "canvas", @@ -226,7 +237,7 @@ def _default_figure_options(cls) -> Options: overwrites style parameters in ``default_style`` in :attr:`options`. Defaults to an empty PlotStyle instance (i.e., ``PlotStyle()``). """ - return Options( + options = Options( xlabel=None, ylabel=None, xlim=None, @@ -237,10 +248,16 @@ def _default_figure_options(cls) -> Options: yval_unit_scale=True, xscale=None, yscale=None, + sharex=True, + sharey=True, figure_title=None, series_params={}, custom_style=PlotStyle(), ) + options.set_validator("xscale", ["linear", "log", "symlog", "logit", "quadratic", None]) + options.set_validator("yscale", ["linear", "log", "symlog", "logit", "quadratic", None]) + + return options def set_options(self, **fields): """Set the drawer options. diff --git a/qiskit_experiments/visualization/drawers/mpl_drawer.py b/qiskit_experiments/visualization/drawers/mpl_drawer.py index 5ea0e4b75c..8ddf696919 100644 --- a/qiskit_experiments/visualization/drawers/mpl_drawer.py +++ b/qiskit_experiments/visualization/drawers/mpl_drawer.py @@ -12,6 +12,7 @@ """Curve drawer for matplotlib backend.""" +import numbers from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -79,6 +80,9 @@ def initialize_canvas(self): else: axis = self.options.axis + sharex = self.figure_options.sharex + sharey = self.figure_options.sharey + n_rows, n_cols = self.options.subplots n_subplots = n_cols * n_rows if n_subplots > 1: @@ -99,9 +103,9 @@ def initialize_canvas(self): inset_ax_h, ] sub_ax = axis.inset_axes(bounds, transform=axis.transAxes, zorder=1) - if j != 0: + if j != 0 and sharey: # remove y axis except for most-left plot - sub_ax.set_yticklabels([]) + sub_ax.yaxis.set_tick_params(labelleft=False) else: # this axis locates at left, write y-label if self.figure_options.ylabel: @@ -110,9 +114,9 @@ def initialize_canvas(self): # Y label can be given as a list for each sub axis label = label[i] sub_ax.set_ylabel(label, fontsize=self.style["axis_label_size"]) - if i != n_rows - 1: + if i != n_rows - 1 and sharex: # remove x axis except for most-bottom plot - sub_ax.set_xticklabels([]) + sub_ax.xaxis.set_tick_params(labelleft=False) else: # this axis locates at bottom, write x-label if self.figure_options.xlabel: @@ -143,115 +147,126 @@ def format_canvas(self): else: all_axes = [self._axis] - # Add data labels if there are multiple labels registered per sub_ax. - for sub_ax in all_axes: + # Get axis formatter from drawing options + formatter_opts = {} + for ax_type in ("x", "y"): + limit = self.figure_options.get(f"{ax_type}lim") + unit = self.figure_options.get(f"{ax_type}val_unit") + unit_scale = self.figure_options.get(f"{ax_type}val_unit_scale") + + # Format options to a list for each axis + if limit is None or isinstance(limit[0], numbers.Number): + limit = [limit] * len(all_axes) + if unit is None or isinstance(unit, str): + unit = [unit] * len(all_axes) + if isinstance(unit_scale, bool): + unit_scale = [unit_scale] * len(all_axes) + + # Compute min-max value for auto scaling + min_vals = [] + max_vals = [] + for sub_ax in all_axes: + if ax_type == "x": + min_v, max_v = sub_ax.get_xlim() + else: + min_v, max_v = sub_ax.get_ylim() + min_vals.append(min_v) + max_vals.append(max_v) + + formatter_opts[ax_type] = { + "limit": limit, + "unit": unit, + "unit_scale": unit_scale, + "min_ax_vals": min_vals, + "max_ax_vals": max_vals, + } + + def signed_sqrt(x): + return np.sign(x) * np.sqrt(abs(x)) + + def signed_square(x): + return np.sign(x) * x**2 + + for i, sub_ax in enumerate(all_axes): + # Add data labels if there are multiple labels registered per sub_ax. _, labels = sub_ax.get_legend_handles_labels() if len(labels) > 1: sub_ax.legend(loc=self.style["legend_loc"]) - # Format x and y axis - for ax_type in ("x", "y"): - # Get axis formatter from drawing options - if ax_type == "x": - lim = self.figure_options.xlim - unit = self.figure_options.xval_unit - unit_scale = self.figure_options.xval_unit_scale - else: - lim = self.figure_options.ylim - unit = self.figure_options.yval_unit - unit_scale = self.figure_options.yval_unit_scale - - # Compute data range from auto scale - if not lim: - v0 = np.nan - v1 = np.nan - for sub_ax in all_axes: - if ax_type == "x": - this_v0, this_v1 = sub_ax.get_xlim() - else: - this_v0, this_v1 = sub_ax.get_ylim() - v0 = np.nanmin([v0, this_v0]) - v1 = np.nanmax([v1, this_v1]) - lim = (v0, v1) - - # Format scaling (log, quadratic, linear, etc.) - def signed_sqrt(x): - return np.sign(x) * np.sqrt(abs(x)) - - def signed_square(x): - return np.sign(x) * x**2 + for ax_type in ("x", "y"): + limit = formatter_opts[ax_type]["limit"][i] + unit = formatter_opts[ax_type]["unit"][i] + unit_scale = formatter_opts[ax_type]["unit_scale"][i] + scale = self.figure_options.get(f"{ax_type}scale") + min_ax_vals = formatter_opts[ax_type]["min_ax_vals"] + max_ax_vals = formatter_opts[ax_type]["max_ax_vals"] + share_axis = self.figure_options.get(f"share{ax_type}") - for sub_ax in all_axes: if ax_type == "x": - scale_opt = self.figure_options.xscale - scale_function = sub_ax.set_xscale + mpl_setscale = sub_ax.set_xscale + mpl_axis_obj = getattr(sub_ax, "xaxis") + mpl_setlimit = sub_ax.set_xlim + mpl_share = sub_ax.sharex else: - scale_opt = self.figure_options.yscale - scale_function = sub_ax.set_yscale - - if scale_opt is not None: - if scale_opt == "quadratic": - scale_function("function", functions=(signed_square, signed_sqrt)) + mpl_setscale = sub_ax.set_yscale + mpl_axis_obj = getattr(sub_ax, "yaxis") + mpl_setlimit = sub_ax.set_ylim + mpl_share = sub_ax.sharey + + if limit is None: + if share_axis: + limit = min(min_ax_vals), max(max_ax_vals) else: - scale_function(scale_opt) - - # Format axis number notation - if unit and unit_scale: - # If value is specified, automatically scale axis magnitude - # and write prefix to axis label, i.e. 1e3 Hz -> 1 kHz - maxv = max(np.abs(lim[0]), np.abs(lim[1])) - try: - scaled_maxv, prefix = detach_prefix(maxv, decimal=3) - prefactor = scaled_maxv / maxv - except ValueError: - prefix = "" - prefactor = 1 - - formatter = MplDrawer.PrefixFormatter(prefactor) - units_str = f" [{prefix}{unit}]" - else: - # Use scientific notation with 3 digits, 1000 -> 1e3 - formatter = ScalarFormatter() - formatter.set_scientific(True) - formatter.set_powerlimits((-3, 3)) - - units_str = f" [{unit}]" if unit else "" + limit = min_ax_vals[i], max_ax_vals[i] - for sub_ax in all_axes: - if ax_type == "x": - ax = getattr(sub_ax, "xaxis") - tick_labels = sub_ax.get_xticklabels() + # Apply non linear axis spacing + if scale is not None: + if scale == "quadratic": + mpl_setscale("function", functions=(signed_square, signed_sqrt)) + else: + mpl_setscale(scale) + + # Create formatter for axis tick label notation + if unit and unit_scale: + # If value is specified, automatically scale axis magnitude + # and write prefix to axis label, i.e. 1e3 Hz -> 1 kHz + maxv = max(np.abs(limit[0]), np.abs(limit[1])) + try: + scaled_maxv, prefix = detach_prefix(maxv, decimal=3) + prefactor = scaled_maxv / maxv + except ValueError: + prefix = "" + prefactor = 1 + formatter = MplDrawer.PrefixFormatter(prefactor) + units_str = f" [{prefix}{unit}]" else: - ax = getattr(sub_ax, "yaxis") - tick_labels = sub_ax.get_yticklabels() - - if tick_labels: - # Set formatter only when tick labels exist - ax.set_major_formatter(formatter) + # Use scientific notation with 3 digits, 1000 -> 1e3 + formatter = ScalarFormatter() + formatter.set_scientific(True) + formatter.set_powerlimits((-3, 3)) + units_str = f" [{unit}]" if unit else "" + mpl_axis_obj.set_major_formatter(formatter) + + # Add units to axis label if both exist if units_str: - # Add units to label if both exist - label_txt_obj = ax.get_label() + label_txt_obj = mpl_axis_obj.get_label() label_str = label_txt_obj.get_text() if label_str: label_txt_obj.set_text(label_str + units_str) - # Auto-scale all axes to the first sub axis - if ax_type == "x": - # get_shared_y_axes() is immutable from matplotlib>=3.6.0. Must use Axis.sharey() - # instead, but this can only be called once per axis. Here we call sharey on all axes in - # a chain, which should have the same effect. - if len(all_axes) > 1: - for ax1, ax2 in zip(all_axes[1:], all_axes[0:-1]): - ax1.sharex(ax2) - all_axes[0].set_xlim(lim) - else: - # get_shared_y_axes() is immutable from matplotlib>=3.6.0. Must use Axis.sharey() - # instead, but this can only be called once per axis. Here we call sharey on all axes in - # a chain, which should have the same effect. - if len(all_axes) > 1: - for ax1, ax2 in zip(all_axes[1:], all_axes[0:-1]): - ax1.sharey(ax2) - all_axes[0].set_ylim(lim) + # Consider axis sharing among subplots + if share_axis: + if i == 0: + # Limit is set to the first axis only. + mpl_setlimit(limit) + else: + # get_shared_*_axes() is immutable from matplotlib>=3.6.0. + # Must use Axis.share*() instead, but this can only be called once per axis. + # Here we call share* on all axes in a chain, which should have the same effect. + mpl_share(all_axes[i - 1]) + else: + mpl_setlimit(limit) + # Add title if self.figure_options.figure_title is not None: self._axis.set_title( diff --git a/qiskit_experiments/visualization/plotters/base_plotter.py b/qiskit_experiments/visualization/plotters/base_plotter.py index c0b86eeb3d..2810184b8d 100644 --- a/qiskit_experiments/visualization/plotters/base_plotter.py +++ b/qiskit_experiments/visualization/plotters/base_plotter.py @@ -421,43 +421,49 @@ def _default_figure_options(cls) -> Options: """Return default figure options. Figure Options: - xlabel (Union[str, List[str]]): X-axis label string of the output figure. - If there are multiple columns in the canvas, this could be a list of - labels. - ylabel (Union[str, List[str]]): Y-axis label string of the output figure. - If there are multiple rows in the canvas, this could be a list of - labels. - xlim (Tuple[float, float]): Min and max value of the horizontal axis. - If not provided, it is automatically scaled based on the input data - points. - ylim (Tuple[float, float]): Min and max value of the vertical axis. - If not provided, it is automatically scaled based on the input data - points. - xval_unit (str): Unit of x values. No scaling prefix is needed here as this - is controlled by ``xval_unit_scale``. - yval_unit (str): Unit of y values. See ``xval_unit`` for details. - xval_unit_scale (bool): Whether to add an SI unit prefix to ``xval_unit`` if - needed. For example, when the x values represent time and + xlabel (Union[str, List[str]]): X-axis label string of the output figure. If + there are multiple columns in the canvas, this could be a list of labels. + ylabel (Union[str, List[str]]): Y-axis label string of the output figure. If + there are multiple rows in the canvas, this could be a list of labels. + xlim (Union[Tuple[float, float], List[Tuple[float, float]]): Min and max value + of the horizontal axis. If not provided, it is automatically scaled based + on the input data points. If there are multiple columns in the canvas, + this could be a list of xlims. + ylim (Union[Tuple[float, float], List[Tuple[float, float]]): Min and max value + of the vertical axis. If not provided, it is automatically scaled based + on the input data points. If there are multiple rows in the canvas, + this could be a list of ylims. + xval_unit (Union[str, List[str]]): Unit of x values. + No scaling prefix is needed here as this is controlled by ``xval_unit_scale``. + If there are multiple columns in the canvas, this could be a list of xval_units. + yval_unit (Union[str, List[str]]): Unit of y values. + No scaling prefix is needed here as this is controlled by ``yval_unit_scale``. + If there are multiple rows in the canvas, this could be a list of yval_units. + xval_unit_scale (Union[bool, List[bool]]): Whether to add an SI unit prefix to + ``xval_unit`` if needed. For example, when the x values represent time and ``xval_unit="s"``, ``xval_unit_scale=True`` adds an SI unit prefix to ``"s"`` based on X values of plotted data. In the output figure, the prefix is automatically selected based on the maximum value in this axis. If your x values are in [1e-3, 1e-4], they are displayed as [1 ms, 10 ms]. By default, this option is set to ``True``. If ``False`` is provided, the axis numbers will be displayed in the scientific notation. - yval_unit_scale (bool): Whether to add an SI unit prefix to ``yval_unit`` if - needed. See ``xval_unit_scale`` for details. - xscale (str): A parameter to the function ``set_xscale()`` to set the - x-axis scaling. Available options are ``log``, ``linear``, ``symlog``, - ``logit``, and ``quadratic``. - yscale (str): See ``xscale`` for details. + If there are multiple columns in the canvas, this could be a list of xval_unit_scale. + yval_unit_scale (Union[bool, List[bool]]): Whether to add an SI unit prefix to + ``yval_unit`` if needed. See ``xval_unit_scale`` for details. + If there are multiple rows in the canvas, this could be a list of yval_unit_scale. + xscale (str): The scaling of the x-axis, such as ``log`` or ``linear``. + yscale (str): The scaling of the y-axis, such as ``log`` or ``linear``. figure_title (str): Title of the figure. Defaults to None, i.e. nothing is shown. - series_params (Dict[SeriesName, Dict[str, Any]]): A dictionary of plot - parameters for each series. This is keyed on the name for each series. - Sub-dictionary is expected to have following three configurations, - "canvas", "color", and "symbol"; "canvas" is the integer index of axis - (when multi-canvas plot is set), "color" is the color of the curve, and - "symbol" is the marker Style of the curve for scatter plots. + sharex (bool): Set True to share x-axis ticks among sub-plots. + sharey (bool): Set True to share y-axis ticks among sub-plots. + series_params (Dict[str, Dict[str, Any]]): A dictionary of parameters for + each series. This is keyed on the name for each series. Sub-dictionary + is expected to have the following three configurations, "canvas", + "color", "symbol" and "label"; "canvas" is the integer index of axis + (when multi-canvas plot is set), "color" is the color of the drawn + graphics, "symbol" is the series marker style for scatter plots, and + "label" is a user provided series label that appears in the legend. """ options = Options( xlabel=None, @@ -470,12 +476,13 @@ def _default_figure_options(cls) -> Options: yval_unit_scale=True, xscale=None, yscale=None, + sharex=True, + sharey=True, figure_title=None, series_params={}, ) - - options.set_validator("xscale", ["linear", "log", "symlog", "logit", "quadratic"]) - options.set_validator("yscale", ["linear", "log", "symlog", "logit", "quadratic"]) + options.set_validator("xscale", ["linear", "log", "symlog", "logit", "quadratic", None]) + options.set_validator("yscale", ["linear", "log", "symlog", "logit", "quadratic", None]) return options diff --git a/releasenotes/notes/add-support-for-visualization-with-unshared-axis-9f7bfe272353086b.yaml b/releasenotes/notes/add-support-for-visualization-with-unshared-axis-9f7bfe272353086b.yaml new file mode 100644 index 0000000000..8c801288ee --- /dev/null +++ b/releasenotes/notes/add-support-for-visualization-with-unshared-axis-9f7bfe272353086b.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + The :class:`.MplDrawer` visualization backend has been upgraded so that + it can take list of options for ``xlim``, ``ylim``, ``xval_unit``, ``yval_unit``, + ``xval_unit_scale``, and ``yval_unit_scale``. New figure options + ``sharex`` and ``sharey`` are also added. The new options are used to unkink the + configuration of sub axes, and default to ``True`` for backward compatibility. + By disabling these options, an experiment author can write an analysis class that + generates a multi-axes figure with different plot ranges.