From ab3f17f6cdda96248e6b3c48d78b1d8714fe42b1 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 8 Dec 2020 09:49:19 -0500 Subject: [PATCH 1/3] Move matplotlib imports to runtime This commit moves all the matplotlib imports to runtime. Prior to this commit when importing the qiskit namespace matplotlib would always be imported if it was installed. This however adds a noticeable bottleneck to the import time, even if visualization isn't used. To avoid this overhead we can lazily import matplotlib at runtime when it is available. This commit makes this change so that matplotlib is only imported at runtime when visualization methods using it are called. This does slow down the first visualization function call using matplotlib, but since visualization isn't a performance critical path this an acceptable tradeoff. For backwards compibility the HAS_MATPLOTLIB module attribute has changed from a bool to an object of a new class HasMatplotlibWrapper, which implements a single method __bool__ which returns true if matplotlib is installed, otherwise it's false. However it does this at runtime instead of at import time. This way it be used as a replacement for the previous HAS_MATPLOTLIB boolean variable but without forcing matplotlib be imported with the entire package. --- qiskit/visualization/counts_visualization.py | 8 +- qiskit/visualization/gate_map.py | 23 +++-- qiskit/visualization/matplotlib.py | 30 +++++-- qiskit/visualization/pulse/matplotlib.py | 20 +++-- qiskit/visualization/pulse_visualization.py | 8 +- qiskit/visualization/state_visualization.py | 91 +++++++++----------- 6 files changed, 98 insertions(+), 82 deletions(-) diff --git a/qiskit/visualization/counts_visualization.py b/qiskit/visualization/counts_visualization.py index 49926fafabb5..f476d56b5367 100644 --- a/qiskit/visualization/counts_visualization.py +++ b/qiskit/visualization/counts_visualization.py @@ -22,11 +22,6 @@ from .matplotlib import HAS_MATPLOTLIB from .exceptions import VisualizationError -if HAS_MATPLOTLIB: - from matplotlib import get_backend - import matplotlib.pyplot as plt - from matplotlib.ticker import MaxNLocator - def hamming_distance(str1, str2): """Calculate the Hamming distance between two bit strings @@ -100,6 +95,9 @@ def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None, """ if not HAS_MATPLOTLIB: raise ImportError('Must have Matplotlib installed.') + from matplotlib import get_backend + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator if sort not in VALID_SORTS: raise VisualizationError("Value of sort option, %s, isn't a " "valid choice. Must be 'asc', " diff --git a/qiskit/visualization/gate_map.py b/qiskit/visualization/gate_map.py index 038134ae5a67..f212ba25d50d 100644 --- a/qiskit/visualization/gate_map.py +++ b/qiskit/visualization/gate_map.py @@ -18,14 +18,6 @@ from .matplotlib import HAS_MATPLOTLIB from .exceptions import VisualizationError -if HAS_MATPLOTLIB: - import matplotlib - from matplotlib import get_backend - import matplotlib.pyplot as plt # pylint: disable=import-error - import matplotlib.patches as mpatches - import matplotlib.gridspec as gridspec - from matplotlib import ticker - class _GraphDist(): """Transform the circles properly for non-square axes. @@ -120,6 +112,12 @@ def plot_gate_map(backend, figsize=None, if not HAS_MATPLOTLIB: raise ImportError('Must have Matplotlib installed. To install, ' 'run "pip install matplotlib".') + import matplotlib + from matplotlib import get_backend + import matplotlib.pyplot as plt # pylint: disable=import-error + import matplotlib.patches as mpatches + import matplotlib.gridspec as gridspec + from matplotlib import ticker if backend.configuration().simulator: raise QiskitError('Requires a device backend, not simulator.') @@ -432,6 +430,15 @@ def plot_error_map(backend, figsize=(12, 9), show_title=True): except ImportError: raise ImportError('Must have seaborn installed to use plot_error_map. ' 'To install, run "pip install seaborn".') + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, ' + 'run "pip install matplotlib".') + import matplotlib + from matplotlib import get_backend + import matplotlib.pyplot as plt # pylint: disable=import-error + import matplotlib.patches as mpatches + import matplotlib.gridspec as gridspec + from matplotlib import ticker color_map = sns.cubehelix_palette(reverse=True, as_cmap=True) diff --git a/qiskit/visualization/matplotlib.py b/qiskit/visualization/matplotlib.py index b0c781844865..2419b75e0132 100644 --- a/qiskit/visualization/matplotlib.py +++ b/qiskit/visualization/matplotlib.py @@ -24,14 +24,6 @@ import numpy as np -try: - from matplotlib import get_backend - from matplotlib import patches - from matplotlib import pyplot as plt - - HAS_MATPLOTLIB = True -except ImportError: - HAS_MATPLOTLIB = False try: from pylatexenc.latex2text import LatexNodes2Text @@ -130,6 +122,8 @@ def __init__(self, qregs, cregs, ops, if not HAS_MATPLOTLIB: raise ImportError('The class MatplotlibDrawer needs matplotlib. ' 'To install, run "pip install matplotlib".') + from matplotlib import patches + from matplotlib import pyplot as plt if not HAS_PYLATEX: raise ImportError('The class MatplotlibDrawer needs pylatexenc. ' 'to install, run "pip install pylatexenc".') @@ -643,6 +637,7 @@ def draw(self, filename=None, verbose=False): self._figure.savefig(filename, dpi=self._style['dpi'], bbox_inches='tight', facecolor=self._figure.get_facecolor()) if self._return_fig: + from matplotlib import get_backend if get_backend() in ['module://ipykernel.pylab.backend_inline', 'nbAgg']: plt.close(self._figure) @@ -1080,3 +1075,22 @@ def _draw_ops(self, verbose=False): self._ax.text(x_coord, y_coord, str(ii + 1), ha='center', va='center', fontsize=sfs, color=self._style['tc'], clip_on=True, zorder=PORDER_TEXT) + + +class HasMatplotlibWrapper: + """Wrapper to lazily import matplotlib.""" + has_matplotlib = False + + def __bool__(self): + if not self.has_matplotlib: + try: + from matplotlib import get_backend + from matplotlib import patches + from matplotlib import pyplot as plt + self.has_matplotlib = True + except ImportError: + self.has_matplotlib = False + return self.has_matplotlib + + +HAS_MATPLOTLIB = HasMatplotlibWrapper() diff --git a/qiskit/visualization/pulse/matplotlib.py b/qiskit/visualization/pulse/matplotlib.py index 29b49c9336ce..141061e90e9b 100644 --- a/qiskit/visualization/pulse/matplotlib.py +++ b/qiskit/visualization/pulse/matplotlib.py @@ -19,12 +19,7 @@ import numpy as np -try: - from matplotlib import pyplot as plt, gridspec - HAS_MATPLOTLIB = True -except ImportError: - HAS_MATPLOTLIB = False - +from qiskit.visualization.matplotlib import HAS_MATPLOTLIB from qiskit.visualization.pulse.qcstyle import PulseStyle, SchedStyle from qiskit.visualization.pulse.interpolation import step_wise from qiskit.pulse.channels import (DriveChannel, ControlChannel, @@ -289,6 +284,12 @@ def draw(self, pulse: Waveform, """ # If these self.style.dpi or self.style.figsize are None, they will # revert back to their default rcParam keys. + if not HAS_MATPLOTLIB: + raise ImportError("Matplotlib needs to be installed to use " + "WaveformDrawer. It can be installed with " + "'pip install matplotlib'") + + from matplotlib import pyplot as plt figure = plt.figure(dpi=self.style.dpi, figsize=self.style.figsize) interp_method = interp_method or step_wise @@ -340,6 +341,13 @@ def __init__(self, style: SchedStyle): Args: style: Style sheet for pulse schedule visualization. """ + if not HAS_MATPLOTLIB: + raise ImportError("Matplotlib needs to be installed to use " + "ScheduleDrawer. It can be installed with " + "'pip install matplotlib'") + + from matplotlib import pyplot as plt + from matplotlib import gridspec self.style = style or SchedStyle() def _build_channels(self, schedule: ScheduleComponent, diff --git a/qiskit/visualization/pulse_visualization.py b/qiskit/visualization/pulse_visualization.py index 2a098b0f8604..bdb35aef553e 100644 --- a/qiskit/visualization/pulse_visualization.py +++ b/qiskit/visualization/pulse_visualization.py @@ -22,12 +22,7 @@ from qiskit.visualization.pulse.qcstyle import PulseStyle, SchedStyle from qiskit.visualization.exceptions import VisualizationError from qiskit.visualization.pulse import matplotlib as _matplotlib - -try: - from matplotlib import get_backend - HAS_MATPLOTLIB = True -except ImportError: - HAS_MATPLOTLIB = False +from qiskit.visualization.matplotlib import HAS_MATPLOTLIB def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]], @@ -145,6 +140,7 @@ def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]], """ if not HAS_MATPLOTLIB: raise ImportError('Must have Matplotlib installed.') + from matplotlib import get_backend if isinstance(data, (SamplePulse, Waveform)): drawer = _matplotlib.WaveformDrawer(style=style) image = drawer.draw(data, dt=dt, interp_method=interp_method, scale=scale) diff --git a/qiskit/visualization/state_visualization.py b/qiskit/visualization/state_visualization.py index 608eeddb63ab..3244120b9f1d 100644 --- a/qiskit/visualization/state_visualization.py +++ b/qiskit/visualization/state_visualization.py @@ -26,37 +26,41 @@ from qiskit.utils.deprecation import deprecate_arguments from .matplotlib import HAS_MATPLOTLIB -if HAS_MATPLOTLIB: - from matplotlib import get_backend - from matplotlib import pyplot as plt - from matplotlib.patches import FancyArrowPatch - from matplotlib.patches import Circle - import matplotlib.colors as mcolors - from matplotlib.colors import Normalize, LightSource - import matplotlib.gridspec as gridspec - from mpl_toolkits.mplot3d import proj3d - from mpl_toolkits.mplot3d.art3d import Poly3DCollection - from qiskit.visualization.exceptions import VisualizationError - from qiskit.visualization.bloch import Bloch - from qiskit.visualization.utils import _bloch_multivector_data, _paulivec_data - from qiskit.circuit.tools.pi_check import pi_check - - -if HAS_MATPLOTLIB: - class Arrow3D(FancyArrowPatch): - """Standard 3D arrow.""" - - def __init__(self, xs, ys, zs, *args, **kwargs): - """Create arrow.""" - FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) - self._verts3d = xs, ys, zs - - def draw(self, renderer): - """Draw the arrow.""" - xs3d, ys3d, zs3d = self._verts3d - xs, ys, _ = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) - self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) - FancyArrowPatch.draw(self, renderer) + +def import_mpl(): + if HAS_MATPLOTLIB: + from matplotlib import get_backend + from matplotlib import pyplot as plt + from matplotlib.patches import FancyArrowPatch + from matplotlib.patches import Circle + import matplotlib.colors as mcolors + from matplotlib.colors import Normalize, LightSource + import matplotlib.gridspec as gridspec + from mpl_toolkits.mplot3d import proj3d + from mpl_toolkits.mplot3d.art3d import Poly3DCollection + from qiskit.visualization.exceptions import VisualizationError + from qiskit.visualization.bloch import Bloch + from qiskit.visualization.utils import _bloch_multivector_data, _paulivec_data + from qiskit.circuit.tools.pi_check import pi_check + + class Arrow3D(FancyArrowPatch): + """Standard 3D arrow.""" + + def __init__(self, xs, ys, zs, *args, **kwargs): + """Create arrow.""" + FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) + self._verts3d = xs, ys, zs + + def draw(self, renderer): + """Draw the arrow.""" + xs3d, ys3d, zs3d = self._verts3d + xs, ys, _ = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) + self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) + FancyArrowPatch.draw(self, renderer) + + else: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') @deprecate_arguments({'rho': 'state'}) @@ -104,9 +108,8 @@ def plot_state_hinton(state, title='', figsize=None, ax_real=None, ax_imag=None, state = DensityMatrix.from_instruction(qc) plot_state_hinton(state, title="New Hinton Plot") """ - if not HAS_MATPLOTLIB: - raise ImportError('Must have Matplotlib installed. To install, run ' - '"pip install matplotlib".') + import_mpl() + # Figure data rho = DensityMatrix(state) num = rho.num_qubits @@ -214,9 +217,7 @@ def plot_bloch_vector(bloch, title="", ax=None, figsize=None, coord_type="cartes plot_bloch_vector([0,1,0], title="New Bloch Sphere") """ - if not HAS_MATPLOTLIB: - raise ImportError('Must have Matplotlib installed. To install, run ' - '"pip install matplotlib".') + import_mpl() if figsize is None: figsize = (5, 5) B = Bloch(axes=ax) @@ -271,9 +272,7 @@ def plot_bloch_multivector(state, title='', figsize=None, *, rho=None): state = Statevector.from_instruction(qc) plot_bloch_multivector(state, title="New Bloch Multivector") """ - if not HAS_MATPLOTLIB: - raise ImportError('Must have Matplotlib installed. To install, run "pip install ' - 'matplotlib".') + import_mpl() # Data bloch_data = _bloch_multivector_data(state) num = len(bloch_data) @@ -344,9 +343,7 @@ def plot_state_city(state, title="", figsize=None, color=None, plot_state_city(state, color=['midnightblue', 'midnightblue'], title="New State City") """ - if not HAS_MATPLOTLIB: - raise ImportError('Must have Matplotlib installed. To install, run "pip install ' - 'matplotlib".') + import_mpl() rho = DensityMatrix(state) num = rho.num_qubits if num is None: @@ -542,9 +539,7 @@ def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, r plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot") """ - if not HAS_MATPLOTLIB: - raise ImportError('Must have Matplotlib installed. To install, run "pip install ' - 'matplotlib".') + import_mpl() labels, values = _paulivec_data(state) numelem = len(values) @@ -687,9 +682,7 @@ def plot_state_qsphere(state, figsize=None, ax=None, show_state_labels=True, state = Statevector.from_instruction(qc) plot_state_qsphere(state) """ - if not HAS_MATPLOTLIB: - raise ImportError('Must have Matplotlib installed. To install, run "pip install ' - 'matplotlib".') + import_mpl() try: import seaborn as sns except ImportError: From f989ed8235983090e5a07e881649a2e06d9fc90a Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 27 Jan 2021 16:21:46 -0500 Subject: [PATCH 2/3] Fix import errors --- qiskit/visualization/matplotlib.py | 24 +++-- qiskit/visualization/state_visualization.py | 112 +++++++++++++------- 2 files changed, 84 insertions(+), 52 deletions(-) diff --git a/qiskit/visualization/matplotlib.py b/qiskit/visualization/matplotlib.py index 3f3d35933bce..d0170e89053c 100644 --- a/qiskit/visualization/matplotlib.py +++ b/qiskit/visualization/matplotlib.py @@ -123,7 +123,9 @@ def __init__(self, qregs, cregs, ops, raise ImportError('The class MatplotlibDrawer needs matplotlib. ' 'To install, run "pip install matplotlib".') from matplotlib import patches + self.patches_mod = patches from matplotlib import pyplot as plt + self.plt_mod = plt if not HAS_PYLATEX: raise ImportError('The class MatplotlibDrawer needs pylatexenc. ' 'to install, run "pip install pylatexenc".') @@ -294,7 +296,7 @@ def _get_text_width(self, text, fontsize, param=False): return 0.0 if self._renderer: - t = plt.text(0.5, 0.5, text, fontsize=fontsize) + t = self.plt_mod.text(0.5, 0.5, text, fontsize=fontsize) return t.get_window_extent(renderer=self._renderer).width / 60.0 else: math_mode_match = self._mathmode_regex.search(text) @@ -417,7 +419,7 @@ def _multiqubit_gate(self, xy, fc=None, ec=None, gt=None, sc=None, text='', subt qubit_span = abs(ypos) - abs(ypos_max) + 1 height = HIG + (qubit_span - 1) - box = patches.Rectangle( + box = self.patches_mod.Rectangle( xy=(xpos - 0.5 * wid, ypos - .5 * HIG), width=wid, height=height, fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE) self._ax.add_patch(box) @@ -452,7 +454,7 @@ def _gate(self, xy, fc=None, ec=None, gt=None, sc=None, text='', subtext=''): sub_width = self._get_text_width(subtext, sfs, param=True) wid = max((text_width, sub_width, WID)) - box = patches.Rectangle(xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG), + box = self.patches_mod.Rectangle(xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG), width=wid, height=HIG, fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE) self._ax.add_patch(box) @@ -509,7 +511,7 @@ def _measure(self, qxy, cxy, cid, fc=None, ec=None, gt=None, sc=None): self._gate(qxy, fc=fc, ec=ec, gt=gt, sc=sc) # add measure symbol - arc = patches.Arc(xy=(qx, qy - 0.15 * HIG), width=WID * 0.7, + arc = self.patches_mod.Arc(xy=(qx, qy - 0.15 * HIG), width=WID * 0.7, height=HIG * 0.7, theta1=0, theta2=180, fill=False, ec=gt, linewidth=self._lwidth2, zorder=PORDER_GATE) self._ax.add_patch(arc) @@ -517,7 +519,7 @@ def _measure(self, qxy, cxy, cid, fc=None, ec=None, gt=None, sc=None): color=gt, linewidth=self._lwidth2, zorder=PORDER_GATE) # arrow self._line(qxy, [cx, cy + 0.35 * WID], lc=self._style['cc'], ls=self._style['cline']) - arrowhead = patches.Polygon(((cx - 0.20 * WID, cy + 0.35 * WID), + arrowhead = self.patches_mod.Polygon(((cx - 0.20 * WID, cy + 0.35 * WID), (cx + 0.20 * WID, cy + 0.35 * WID), (cx, cy + 0.04)), fc=self._style['cc'], ec=None) self._ax.add_artist(arrowhead) @@ -531,13 +533,13 @@ def _conditional(self, xy, istrue=False): xpos, ypos = xy fc = self._style['lc'] if istrue else self._style['bg'] - box = patches.Circle(xy=(xpos, ypos), radius=WID * 0.15, fc=fc, + box = self.patches_mod.Circle(xy=(xpos, ypos), radius=WID * 0.15, fc=fc, ec=self._style['lc'], linewidth=self._lwidth15, zorder=PORDER_GATE) self._ax.add_patch(box) def _ctrl_qubit(self, xy, fc=None, ec=None, tc=None, text='', text_top=None): xpos, ypos = xy - box = patches.Circle(xy=(xpos, ypos), radius=WID * 0.15, + box = self.patches_mod.Circle(xy=(xpos, ypos), radius=WID * 0.15, fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE) self._ax.add_patch(box) # display the control label at the top or bottom if there is one @@ -577,7 +579,7 @@ def _set_ctrl_bits(self, ctrl_state, num_ctrl_qubits, qbit, ec=None, tc=None, def _x_tgt_qubit(self, xy, ec=None, ac=None): linewidth = self._lwidth2 xpos, ypos = xy - box = patches.Circle(xy=(xpos, ypos), radius=HIG * 0.35, + box = self.patches_mod.Circle(xy=(xpos, ypos), radius=HIG * 0.35, fc=ec, ec=ec, linewidth=linewidth, zorder=PORDER_GATE) self._ax.add_patch(box) @@ -605,7 +607,7 @@ def _barrier(self, config): self._ax.plot([xpos, xpos], [ypos + 0.5, ypos - 0.5], linewidth=self._scale, linestyle="dashed", color=self._style['lc'], zorder=PORDER_TEXT) - box = patches.Rectangle(xy=(xpos - (0.3 * WID), ypos - 0.5), + box = self.patches_mod.Rectangle(xy=(xpos - (0.3 * WID), ypos - 0.5), width=0.6 * WID, height=1, fc=self._style['bc'], ec=None, alpha=0.6, linewidth=self._lwidth15, zorder=PORDER_GRAY) @@ -630,7 +632,7 @@ def draw(self, filename=None, verbose=False): self._figure.set_size_inches(self._style['figwidth'], self._style['figwidth'] * fig_h / fig_w) if self._global_phase: - plt.text(_xl, _yt, 'Global Phase: %s' % pi_check(self._global_phase, + self.plt_mod.text(_xl, _yt, 'Global Phase: %s' % pi_check(self._global_phase, output='mpl')) if filename: @@ -640,7 +642,7 @@ def draw(self, filename=None, verbose=False): from matplotlib import get_backend if get_backend() in ['module://ipykernel.pylab.backend_inline', 'nbAgg']: - plt.close(self._figure) + self.plt_mod.close(self._figure) return self._figure def _draw_regs(self): diff --git a/qiskit/visualization/state_visualization.py b/qiskit/visualization/state_visualization.py index 3244120b9f1d..add19835ed5d 100644 --- a/qiskit/visualization/state_visualization.py +++ b/qiskit/visualization/state_visualization.py @@ -26,41 +26,9 @@ from qiskit.utils.deprecation import deprecate_arguments from .matplotlib import HAS_MATPLOTLIB - -def import_mpl(): - if HAS_MATPLOTLIB: - from matplotlib import get_backend - from matplotlib import pyplot as plt - from matplotlib.patches import FancyArrowPatch - from matplotlib.patches import Circle - import matplotlib.colors as mcolors - from matplotlib.colors import Normalize, LightSource - import matplotlib.gridspec as gridspec - from mpl_toolkits.mplot3d import proj3d - from mpl_toolkits.mplot3d.art3d import Poly3DCollection - from qiskit.visualization.exceptions import VisualizationError - from qiskit.visualization.bloch import Bloch - from qiskit.visualization.utils import _bloch_multivector_data, _paulivec_data - from qiskit.circuit.tools.pi_check import pi_check - - class Arrow3D(FancyArrowPatch): - """Standard 3D arrow.""" - - def __init__(self, xs, ys, zs, *args, **kwargs): - """Create arrow.""" - FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) - self._verts3d = xs, ys, zs - - def draw(self, renderer): - """Draw the arrow.""" - xs3d, ys3d, zs3d = self._verts3d - xs, ys, _ = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) - self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) - FancyArrowPatch.draw(self, renderer) - - else: - raise ImportError('Must have Matplotlib installed. To install, run ' - '"pip install matplotlib".') +from qiskit.visualization.exceptions import VisualizationError +from qiskit.visualization.utils import _bloch_multivector_data, _paulivec_data +from qiskit.circuit.tools.pi_check import pi_check @deprecate_arguments({'rho': 'state'}) @@ -108,7 +76,11 @@ def plot_state_hinton(state, title='', figsize=None, ax_real=None, ax_imag=None, state = DensityMatrix.from_instruction(qc) plot_state_hinton(state, title="New Hinton Plot") """ - import_mpl() + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + from matplotlib import pyplot as plt + from matplotlib import get_backend # Figure data rho = DensityMatrix(state) @@ -217,7 +189,13 @@ def plot_bloch_vector(bloch, title="", ax=None, figsize=None, coord_type="cartes plot_bloch_vector([0,1,0], title="New Bloch Sphere") """ - import_mpl() + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + from qiskit.visualization.bloch import Bloch + from matplotlib import get_backend + from matplotlib import pyplot as plt + if figsize is None: figsize = (5, 5) B = Bloch(axes=ax) @@ -272,7 +250,12 @@ def plot_bloch_multivector(state, title='', figsize=None, *, rho=None): state = Statevector.from_instruction(qc) plot_bloch_multivector(state, title="New Bloch Multivector") """ - import_mpl() + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + from matplotlib import get_backend + from matplotlib import pyplot as plt + # Data bloch_data = _bloch_multivector_data(state) num = len(bloch_data) @@ -343,7 +326,13 @@ def plot_state_city(state, title="", figsize=None, color=None, plot_state_city(state, color=['midnightblue', 'midnightblue'], title="New State City") """ - import_mpl() + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + from matplotlib import get_backend + from matplotlib import pyplot as plt + from mpl_toolkits.mplot3d.art3d import Poly3DCollection + rho = DensityMatrix(state) num = rho.num_qubits if num is None: @@ -539,7 +528,12 @@ def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, r plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot") """ - import_mpl() + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + from matplotlib import get_backend + from matplotlib import pyplot as plt + labels, values = _paulivec_data(state) numelem = len(values) @@ -682,7 +676,31 @@ def plot_state_qsphere(state, figsize=None, ax=None, show_state_labels=True, state = Statevector.from_instruction(qc) plot_state_qsphere(state) """ - import_mpl() + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + + from mpl_toolkits.mplot3d import proj3d + from matplotlib.patches import FancyArrowPatch + import matplotlib.gridspec as gridspec + from matplotlib import pyplot as plt + from matplotlib.patches import Circle + + class Arrow3D(FancyArrowPatch): + """Standard 3D arrow.""" + + def __init__(self, xs, ys, zs, *args, **kwargs): + """Create arrow.""" + FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) + self._verts3d = xs, ys, zs + + def draw(self, renderer): + """Draw the arrow.""" + xs3d, ys3d, zs3d = self._verts3d + xs, ys, _ = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) + self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) + FancyArrowPatch.draw(self, renderer) + try: import seaborn as sns except ImportError: @@ -874,6 +892,11 @@ def generate_facecolors(x, y, z, dx, dy, dz, color): Returns: list: Shaded colors for bars. """ + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + import matplotlib.colors as mcolors + cuboid = np.array([ # -z ( @@ -990,6 +1013,13 @@ def _shade_colors(color, normals, lightsource=None): Shade *color* using normal vectors given by *normals*. *color* can also be an array of the same length as *normals*. """ + if not HAS_MATPLOTLIB: + raise ImportError('Must have Matplotlib installed. To install, run ' + '"pip install matplotlib".') + + from matplotlib.colors import Normalize, LightSource + import matplotlib.colors as mcolors + if lightsource is None: # chosen for backwards-compatibility lightsource = LightSource(azdeg=225, altdeg=19.4712) From c148dd1a1816523d46831ad65bf87befeed400e5 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 27 Jan 2021 16:41:36 -0500 Subject: [PATCH 3/3] Fix lint --- qiskit/visualization/gate_map.py | 4 --- qiskit/visualization/matplotlib.py | 30 +++++++++++---------- qiskit/visualization/pulse/matplotlib.py | 15 ++++++++--- qiskit/visualization/pulse_visualization.py | 5 +++- qiskit/visualization/state_visualization.py | 6 +++-- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/qiskit/visualization/gate_map.py b/qiskit/visualization/gate_map.py index e31427504553..790720cba5e6 100644 --- a/qiskit/visualization/gate_map.py +++ b/qiskit/visualization/gate_map.py @@ -112,12 +112,9 @@ def plot_gate_map(backend, figsize=None, if not HAS_MATPLOTLIB: raise ImportError('Must have Matplotlib installed. To install, ' 'run "pip install matplotlib".') - import matplotlib from matplotlib import get_backend import matplotlib.pyplot as plt # pylint: disable=import-error import matplotlib.patches as mpatches - import matplotlib.gridspec as gridspec - from matplotlib import ticker if backend.configuration().simulator: raise QiskitError('Requires a device backend, not simulator.') @@ -440,7 +437,6 @@ def plot_error_map(backend, figsize=(12, 9), show_title=True): import matplotlib from matplotlib import get_backend import matplotlib.pyplot as plt # pylint: disable=import-error - import matplotlib.patches as mpatches import matplotlib.gridspec as gridspec from matplotlib import ticker diff --git a/qiskit/visualization/matplotlib.py b/qiskit/visualization/matplotlib.py index d0170e89053c..ca4b9068c79f 100644 --- a/qiskit/visualization/matplotlib.py +++ b/qiskit/visualization/matplotlib.py @@ -455,8 +455,8 @@ def _gate(self, xy, fc=None, ec=None, gt=None, sc=None, text='', subtext=''): wid = max((text_width, sub_width, WID)) box = self.patches_mod.Rectangle(xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG), - width=wid, height=HIG, fc=fc, ec=ec, - linewidth=self._lwidth15, zorder=PORDER_GATE) + width=wid, height=HIG, fc=fc, ec=ec, + linewidth=self._lwidth15, zorder=PORDER_GATE) self._ax.add_patch(box) if text: @@ -512,16 +512,16 @@ def _measure(self, qxy, cxy, cid, fc=None, ec=None, gt=None, sc=None): # add measure symbol arc = self.patches_mod.Arc(xy=(qx, qy - 0.15 * HIG), width=WID * 0.7, - height=HIG * 0.7, theta1=0, theta2=180, fill=False, - ec=gt, linewidth=self._lwidth2, zorder=PORDER_GATE) + height=HIG * 0.7, theta1=0, theta2=180, fill=False, + ec=gt, linewidth=self._lwidth2, zorder=PORDER_GATE) self._ax.add_patch(arc) self._ax.plot([qx, qx + 0.35 * WID], [qy - 0.15 * HIG, qy + 0.20 * HIG], color=gt, linewidth=self._lwidth2, zorder=PORDER_GATE) # arrow self._line(qxy, [cx, cy + 0.35 * WID], lc=self._style['cc'], ls=self._style['cline']) arrowhead = self.patches_mod.Polygon(((cx - 0.20 * WID, cy + 0.35 * WID), - (cx + 0.20 * WID, cy + 0.35 * WID), - (cx, cy + 0.04)), fc=self._style['cc'], ec=None) + (cx + 0.20 * WID, cy + 0.35 * WID), + (cx, cy + 0.04)), fc=self._style['cc'], ec=None) self._ax.add_artist(arrowhead) # target if self._cregbundle: @@ -534,13 +534,14 @@ def _conditional(self, xy, istrue=False): fc = self._style['lc'] if istrue else self._style['bg'] box = self.patches_mod.Circle(xy=(xpos, ypos), radius=WID * 0.15, fc=fc, - ec=self._style['lc'], linewidth=self._lwidth15, zorder=PORDER_GATE) + ec=self._style['lc'], linewidth=self._lwidth15, + zorder=PORDER_GATE) self._ax.add_patch(box) def _ctrl_qubit(self, xy, fc=None, ec=None, tc=None, text='', text_top=None): xpos, ypos = xy box = self.patches_mod.Circle(xy=(xpos, ypos), radius=WID * 0.15, - fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE) + fc=fc, ec=ec, linewidth=self._lwidth15, zorder=PORDER_GATE) self._ax.add_patch(box) # display the control label at the top or bottom if there is one if text_top is True: @@ -580,8 +581,8 @@ def _x_tgt_qubit(self, xy, ec=None, ac=None): linewidth = self._lwidth2 xpos, ypos = xy box = self.patches_mod.Circle(xy=(xpos, ypos), radius=HIG * 0.35, - fc=ec, ec=ec, linewidth=linewidth, - zorder=PORDER_GATE) + fc=ec, ec=ec, linewidth=linewidth, + zorder=PORDER_GATE) self._ax.add_patch(box) # add '+' symbol @@ -608,9 +609,9 @@ def _barrier(self, config): linewidth=self._scale, linestyle="dashed", color=self._style['lc'], zorder=PORDER_TEXT) box = self.patches_mod.Rectangle(xy=(xpos - (0.3 * WID), ypos - 0.5), - width=0.6 * WID, height=1, - fc=self._style['bc'], ec=None, alpha=0.6, - linewidth=self._lwidth15, zorder=PORDER_GRAY) + width=0.6 * WID, height=1, + fc=self._style['bc'], ec=None, alpha=0.6, + linewidth=self._lwidth15, zorder=PORDER_GRAY) self._ax.add_patch(box) def draw(self, filename=None, verbose=False): @@ -633,7 +634,7 @@ def draw(self, filename=None, verbose=False): self._style['figwidth'] * fig_h / fig_w) if self._global_phase: self.plt_mod.text(_xl, _yt, 'Global Phase: %s' % pi_check(self._global_phase, - output='mpl')) + output='mpl')) if filename: self._figure.savefig(filename, dpi=self._style['dpi'], bbox_inches='tight', @@ -1084,6 +1085,7 @@ class HasMatplotlibWrapper: """Wrapper to lazily import matplotlib.""" has_matplotlib = False + # pylint: disable=unused-import def __bool__(self): if not self.has_matplotlib: try: diff --git a/qiskit/visualization/pulse/matplotlib.py b/qiskit/visualization/pulse/matplotlib.py index c4adbd2237c6..71d98c6d71ed 100644 --- a/qiskit/visualization/pulse/matplotlib.py +++ b/qiskit/visualization/pulse/matplotlib.py @@ -281,6 +281,9 @@ def draw(self, pulse: Waveform, Returns: matplotlib.figure.Figure: A matplotlib figure object of the pulse envelope. + + Raises: + ImportError: If matplotlib is not installed """ # If these self.style.dpi or self.style.figsize are None, they will # revert back to their default rcParam keys. @@ -340,6 +343,8 @@ def __init__(self, style: SchedStyle): Args: style: Style sheet for pulse schedule visualization. + Raises: + ImportError: If matplotlib is not installed """ if not HAS_MATPLOTLIB: raise ImportError("Matplotlib needs to be installed to use " @@ -347,7 +352,9 @@ def __init__(self, style: SchedStyle): "'pip install matplotlib'") from matplotlib import pyplot as plt + self.plt_mod = plt from matplotlib import gridspec + self.gridspec_mod = gridspec self.style = style or SchedStyle() def _build_channels(self, schedule: ScheduleComponent, @@ -517,9 +524,9 @@ def _draw_table(self, figure, h_waves = (figure.get_size_inches()[1] - h_table) # create subplots - gs = gridspec.GridSpec(2, 1, height_ratios=[h_table, h_waves], hspace=0) - tb = plt.subplot(gs[0]) - ax = plt.subplot(gs[1]) + gs = self.gridspec_mod.GridSpec(2, 1, height_ratios=[h_table, h_waves], hspace=0) + tb = self.plt_mod.subplot(gs[0]) + ax = self.plt_mod.subplot(gs[1]) # configure each cell tb.axis('off') @@ -803,7 +810,7 @@ def draw(self, schedule: ScheduleComponent, Raises: VisualizationError: When schedule cannot be drawn """ - figure = plt.figure(dpi=self.style.dpi, figsize=self.style.figsize) + figure = self.plt_mod.figure(dpi=self.style.dpi, figsize=self.style.figsize) if channels is None: channels = [] diff --git a/qiskit/visualization/pulse_visualization.py b/qiskit/visualization/pulse_visualization.py index eb2f76ecd0ee..33fa532280af 100644 --- a/qiskit/visualization/pulse_visualization.py +++ b/qiskit/visualization/pulse_visualization.py @@ -151,6 +151,9 @@ def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]], if not HAS_MATPLOTLIB: raise ImportError('Must have Matplotlib installed.') + from matplotlib import get_backend + from matplotlib import pyplot as plt + if isinstance(data, Waveform): drawer = _matplotlib.WaveformDrawer(style=style) image = drawer.draw(data, dt=dt, interp_method=interp_method, scale=scale) @@ -170,7 +173,7 @@ def pulse_drawer(data: Union[Waveform, Union[Schedule, Instruction]], if get_backend() in ['module://ipykernel.pylab.backend_inline', 'nbAgg']: - _matplotlib.plt.close(image) + plt.close(image) if image and interactive: image.show() return image diff --git a/qiskit/visualization/state_visualization.py b/qiskit/visualization/state_visualization.py index add19835ed5d..030729fb5373 100644 --- a/qiskit/visualization/state_visualization.py +++ b/qiskit/visualization/state_visualization.py @@ -24,8 +24,7 @@ from scipy import linalg from qiskit.quantum_info.states import DensityMatrix from qiskit.utils.deprecation import deprecate_arguments -from .matplotlib import HAS_MATPLOTLIB - +from qiskit.visualization.matplotlib import HAS_MATPLOTLIB from qiskit.visualization.exceptions import VisualizationError from qiskit.visualization.utils import _bloch_multivector_data, _paulivec_data from qiskit.circuit.tools.pi_check import pi_check @@ -685,6 +684,7 @@ def plot_state_qsphere(state, figsize=None, ax=None, show_state_labels=True, import matplotlib.gridspec as gridspec from matplotlib import pyplot as plt from matplotlib.patches import Circle + from matplotlib import get_backend class Arrow3D(FancyArrowPatch): """Standard 3D arrow.""" @@ -891,6 +891,8 @@ def generate_facecolors(x, y, z, dx, dy, dz, color): color (array_like): sequence of valid color specifications, optional Returns: list: Shaded colors for bars. + Raises: + ImportError: If matplotlib is not installed """ if not HAS_MATPLOTLIB: raise ImportError('Must have Matplotlib installed. To install, run '