Skip to content

Commit

Permalink
Update mpl drawer backend to support unshared axes (#1242)
Browse files Browse the repository at this point in the history
### Summary

Experiment plotter can create a multi-sub canvas (axes) figure, but x
and y plot range (xlim/ylim in matplotlib) must be shared among sub
axes. This PR relaxes this limitation and allows a developer to write
analysis class that can generate figure containing axes with different
plot ranges.

### Details and comments

Several figure options are updated to take a list of options for each
axis. New options ``sharex`` and ``sharey`` are added to unlink the axis
settings. These new options are True by default for backward
compatibility.
  • Loading branch information
nkanazawa1989 authored Sep 19, 2023
1 parent b27f33c commit 9842de1
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 143 deletions.
43 changes: 30 additions & 13 deletions qiskit_experiments/visualization/drawers/base_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,27 +194,38 @@ def _default_figure_options(cls) -> Options:
there are multiple columns in the canvas, this could be a list of labels.
ylabel (Union[str, List[str]]): Y-axis label string of the output figure. If
there are multiple rows in the canvas, this could be a list of labels.
xlim (Tuple[float, float]): Min and max value of the horizontal axis. If not
provided, it is automatically scaled based on the input data points.
ylim (Tuple[float, float]): Min and max value of the vertical axis. If not
provided, it is automatically scaled based on the input data points.
xval_unit (str): Unit of x values. No scaling prefix is needed here as this
is controlled by ``xval_unit_scale``.
yval_unit (str): Unit of y values. See ``xval_unit`` for details.
xval_unit_scale (bool): Whether to add an SI unit prefix to ``xval_unit`` if
needed. For example, when the x values represent time and
xlim (Union[Tuple[float, float], List[Tuple[float, float]]): Min and max value
of the horizontal axis. If not provided, it is automatically scaled based
on the input data points. If there are multiple columns in the canvas,
this could be a list of xlims.
ylim (Union[Tuple[float, float], List[Tuple[float, float]]): Min and max value
of the vertical axis. If not provided, it is automatically scaled based
on the input data points. If there are multiple rows in the canvas,
this could be a list of ylims.
xval_unit (Union[str, List[str]]): Unit of x values.
No scaling prefix is needed here as this is controlled by ``xval_unit_scale``.
If there are multiple columns in the canvas, this could be a list of xval_units.
yval_unit (Union[str, List[str]]): Unit of y values.
No scaling prefix is needed here as this is controlled by ``yval_unit_scale``.
If there are multiple rows in the canvas, this could be a list of yval_units.
xval_unit_scale (Union[bool, List[bool]]): Whether to add an SI unit prefix to
``xval_unit`` if needed. For example, when the x values represent time and
``xval_unit="s"``, ``xval_unit_scale=True`` adds an SI unit prefix to
``"s"`` based on X values of plotted data. In the output figure, the
prefix is automatically selected based on the maximum value in this
axis. If your x values are in [1e-3, 1e-4], they are displayed as [1 ms,
10 ms]. By default, this option is set to ``True``. If ``False`` is
provided, the axis numbers will be displayed in the scientific notation.
yval_unit_scale (bool): Whether to add an SI unit prefix to ``yval_unit`` if
needed. See ``xval_unit_scale`` for details.
If there are multiple columns in the canvas, this could be a list of xval_unit_scale.
yval_unit_scale (Union[bool, List[bool]]): Whether to add an SI unit prefix to
``yval_unit`` if needed. See ``xval_unit_scale`` for details.
If there are multiple rows in the canvas, this could be a list of yval_unit_scale.
xscale (str): The scaling of the x-axis, such as ``log`` or ``linear``.
yscale (str): See ``xscale`` for details.
yscale (str): The scaling of the y-axis, such as ``log`` or ``linear``.
figure_title (str): Title of the figure. Defaults to None, i.e. nothing is
shown.
sharex (bool): Set True to share x-axis ticks among sub-plots.
sharey (bool): Set True to share y-axis ticks among sub-plots.
series_params (Dict[str, Dict[str, Any]]): A dictionary of parameters for
each series. This is keyed on the name for each series. Sub-dictionary
is expected to have the following three configurations, "canvas",
Expand All @@ -226,7 +237,7 @@ def _default_figure_options(cls) -> Options:
overwrites style parameters in ``default_style`` in :attr:`options`.
Defaults to an empty PlotStyle instance (i.e., ``PlotStyle()``).
"""
return Options(
options = Options(
xlabel=None,
ylabel=None,
xlim=None,
Expand All @@ -237,10 +248,16 @@ def _default_figure_options(cls) -> Options:
yval_unit_scale=True,
xscale=None,
yscale=None,
sharex=True,
sharey=True,
figure_title=None,
series_params={},
custom_style=PlotStyle(),
)
options.set_validator("xscale", ["linear", "log", "symlog", "logit", "quadratic", None])
options.set_validator("yscale", ["linear", "log", "symlog", "logit", "quadratic", None])

return options

def set_options(self, **fields):
"""Set the drawer options.
Expand Down
211 changes: 113 additions & 98 deletions qiskit_experiments/visualization/drawers/mpl_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

"""Curve drawer for matplotlib backend."""

import numbers
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -79,6 +80,9 @@ def initialize_canvas(self):
else:
axis = self.options.axis

sharex = self.figure_options.sharex
sharey = self.figure_options.sharey

n_rows, n_cols = self.options.subplots
n_subplots = n_cols * n_rows
if n_subplots > 1:
Expand All @@ -99,9 +103,9 @@ def initialize_canvas(self):
inset_ax_h,
]
sub_ax = axis.inset_axes(bounds, transform=axis.transAxes, zorder=1)
if j != 0:
if j != 0 and sharey:
# remove y axis except for most-left plot
sub_ax.set_yticklabels([])
sub_ax.yaxis.set_tick_params(labelleft=False)
else:
# this axis locates at left, write y-label
if self.figure_options.ylabel:
Expand All @@ -110,9 +114,9 @@ def initialize_canvas(self):
# Y label can be given as a list for each sub axis
label = label[i]
sub_ax.set_ylabel(label, fontsize=self.style["axis_label_size"])
if i != n_rows - 1:
if i != n_rows - 1 and sharex:
# remove x axis except for most-bottom plot
sub_ax.set_xticklabels([])
sub_ax.xaxis.set_tick_params(labelleft=False)
else:
# this axis locates at bottom, write x-label
if self.figure_options.xlabel:
Expand Down Expand Up @@ -143,115 +147,126 @@ def format_canvas(self):
else:
all_axes = [self._axis]

# Add data labels if there are multiple labels registered per sub_ax.
for sub_ax in all_axes:
# Get axis formatter from drawing options
formatter_opts = {}
for ax_type in ("x", "y"):
limit = self.figure_options.get(f"{ax_type}lim")
unit = self.figure_options.get(f"{ax_type}val_unit")
unit_scale = self.figure_options.get(f"{ax_type}val_unit_scale")

# Format options to a list for each axis
if limit is None or isinstance(limit[0], numbers.Number):
limit = [limit] * len(all_axes)
if unit is None or isinstance(unit, str):
unit = [unit] * len(all_axes)
if isinstance(unit_scale, bool):
unit_scale = [unit_scale] * len(all_axes)

# Compute min-max value for auto scaling
min_vals = []
max_vals = []
for sub_ax in all_axes:
if ax_type == "x":
min_v, max_v = sub_ax.get_xlim()
else:
min_v, max_v = sub_ax.get_ylim()
min_vals.append(min_v)
max_vals.append(max_v)

formatter_opts[ax_type] = {
"limit": limit,
"unit": unit,
"unit_scale": unit_scale,
"min_ax_vals": min_vals,
"max_ax_vals": max_vals,
}

def signed_sqrt(x):
return np.sign(x) * np.sqrt(abs(x))

def signed_square(x):
return np.sign(x) * x**2

for i, sub_ax in enumerate(all_axes):
# Add data labels if there are multiple labels registered per sub_ax.
_, labels = sub_ax.get_legend_handles_labels()
if len(labels) > 1:
sub_ax.legend(loc=self.style["legend_loc"])

# Format x and y axis
for ax_type in ("x", "y"):
# Get axis formatter from drawing options
if ax_type == "x":
lim = self.figure_options.xlim
unit = self.figure_options.xval_unit
unit_scale = self.figure_options.xval_unit_scale
else:
lim = self.figure_options.ylim
unit = self.figure_options.yval_unit
unit_scale = self.figure_options.yval_unit_scale

# Compute data range from auto scale
if not lim:
v0 = np.nan
v1 = np.nan
for sub_ax in all_axes:
if ax_type == "x":
this_v0, this_v1 = sub_ax.get_xlim()
else:
this_v0, this_v1 = sub_ax.get_ylim()
v0 = np.nanmin([v0, this_v0])
v1 = np.nanmax([v1, this_v1])
lim = (v0, v1)

# Format scaling (log, quadratic, linear, etc.)
def signed_sqrt(x):
return np.sign(x) * np.sqrt(abs(x))

def signed_square(x):
return np.sign(x) * x**2
for ax_type in ("x", "y"):
limit = formatter_opts[ax_type]["limit"][i]
unit = formatter_opts[ax_type]["unit"][i]
unit_scale = formatter_opts[ax_type]["unit_scale"][i]
scale = self.figure_options.get(f"{ax_type}scale")
min_ax_vals = formatter_opts[ax_type]["min_ax_vals"]
max_ax_vals = formatter_opts[ax_type]["max_ax_vals"]
share_axis = self.figure_options.get(f"share{ax_type}")

for sub_ax in all_axes:
if ax_type == "x":
scale_opt = self.figure_options.xscale
scale_function = sub_ax.set_xscale
mpl_setscale = sub_ax.set_xscale
mpl_axis_obj = getattr(sub_ax, "xaxis")
mpl_setlimit = sub_ax.set_xlim
mpl_share = sub_ax.sharex
else:
scale_opt = self.figure_options.yscale
scale_function = sub_ax.set_yscale

if scale_opt is not None:
if scale_opt == "quadratic":
scale_function("function", functions=(signed_square, signed_sqrt))
mpl_setscale = sub_ax.set_yscale
mpl_axis_obj = getattr(sub_ax, "yaxis")
mpl_setlimit = sub_ax.set_ylim
mpl_share = sub_ax.sharey

if limit is None:
if share_axis:
limit = min(min_ax_vals), max(max_ax_vals)
else:
scale_function(scale_opt)

# Format axis number notation
if unit and unit_scale:
# If value is specified, automatically scale axis magnitude
# and write prefix to axis label, i.e. 1e3 Hz -> 1 kHz
maxv = max(np.abs(lim[0]), np.abs(lim[1]))
try:
scaled_maxv, prefix = detach_prefix(maxv, decimal=3)
prefactor = scaled_maxv / maxv
except ValueError:
prefix = ""
prefactor = 1

formatter = MplDrawer.PrefixFormatter(prefactor)
units_str = f" [{prefix}{unit}]"
else:
# Use scientific notation with 3 digits, 1000 -> 1e3
formatter = ScalarFormatter()
formatter.set_scientific(True)
formatter.set_powerlimits((-3, 3))

units_str = f" [{unit}]" if unit else ""
limit = min_ax_vals[i], max_ax_vals[i]

for sub_ax in all_axes:
if ax_type == "x":
ax = getattr(sub_ax, "xaxis")
tick_labels = sub_ax.get_xticklabels()
# Apply non linear axis spacing
if scale is not None:
if scale == "quadratic":
mpl_setscale("function", functions=(signed_square, signed_sqrt))
else:
mpl_setscale(scale)

# Create formatter for axis tick label notation
if unit and unit_scale:
# If value is specified, automatically scale axis magnitude
# and write prefix to axis label, i.e. 1e3 Hz -> 1 kHz
maxv = max(np.abs(limit[0]), np.abs(limit[1]))
try:
scaled_maxv, prefix = detach_prefix(maxv, decimal=3)
prefactor = scaled_maxv / maxv
except ValueError:
prefix = ""
prefactor = 1
formatter = MplDrawer.PrefixFormatter(prefactor)
units_str = f" [{prefix}{unit}]"
else:
ax = getattr(sub_ax, "yaxis")
tick_labels = sub_ax.get_yticklabels()

if tick_labels:
# Set formatter only when tick labels exist
ax.set_major_formatter(formatter)
# Use scientific notation with 3 digits, 1000 -> 1e3
formatter = ScalarFormatter()
formatter.set_scientific(True)
formatter.set_powerlimits((-3, 3))
units_str = f" [{unit}]" if unit else ""
mpl_axis_obj.set_major_formatter(formatter)

# Add units to axis label if both exist
if units_str:
# Add units to label if both exist
label_txt_obj = ax.get_label()
label_txt_obj = mpl_axis_obj.get_label()
label_str = label_txt_obj.get_text()
if label_str:
label_txt_obj.set_text(label_str + units_str)

# Auto-scale all axes to the first sub axis
if ax_type == "x":
# get_shared_y_axes() is immutable from matplotlib>=3.6.0. Must use Axis.sharey()
# instead, but this can only be called once per axis. Here we call sharey on all axes in
# a chain, which should have the same effect.
if len(all_axes) > 1:
for ax1, ax2 in zip(all_axes[1:], all_axes[0:-1]):
ax1.sharex(ax2)
all_axes[0].set_xlim(lim)
else:
# get_shared_y_axes() is immutable from matplotlib>=3.6.0. Must use Axis.sharey()
# instead, but this can only be called once per axis. Here we call sharey on all axes in
# a chain, which should have the same effect.
if len(all_axes) > 1:
for ax1, ax2 in zip(all_axes[1:], all_axes[0:-1]):
ax1.sharey(ax2)
all_axes[0].set_ylim(lim)
# Consider axis sharing among subplots
if share_axis:
if i == 0:
# Limit is set to the first axis only.
mpl_setlimit(limit)
else:
# get_shared_*_axes() is immutable from matplotlib>=3.6.0.
# Must use Axis.share*() instead, but this can only be called once per axis.
# Here we call share* on all axes in a chain, which should have the same effect.
mpl_share(all_axes[i - 1])
else:
mpl_setlimit(limit)

# Add title
if self.figure_options.figure_title is not None:
self._axis.set_title(
Expand Down
Loading

0 comments on commit 9842de1

Please sign in to comment.